#!/usr/bin/env python3
# Wine Vulkan generator
#
# Copyright 2017-2018 Roderick Colenbrander
# Copyright 2022 Jacek Caban for CodeWeavers
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
#  License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
#

import argparse
import logging
import os
import re
import urllib.request
import xml.etree.ElementTree as ET
from collections import OrderedDict
from collections.abc import Sequence
from enum import Enum

# This script generates code for a Wine Vulkan ICD driver from Vulkan's xr.xml.
# Generating the code is like 10x worse than OpenGL, which is mostly a calling
# convention passthrough.
#
# The script parses xr.xml and maps functions and types to helper objects. These
# helper objects simplify the xml parsing and map closely to the Vulkan types.
# The code generation utilizes the helper objects during code generation and
# most of the ugly work is carried out by these objects.
#
# Vulkan ICD challenges:
# - Vulkan ICD loader (openxr-1.dll) relies on a section at the start of
#   'dispatchable handles' (e.g. XrDevice, XrInstance) for it to insert
#   its private data. It uses this area to store its own dispatch tables
#   for loader internal use. This means any dispatchable objects need wrapping.
#
# - Vulkan structures have different alignment between win32 and 32-bit Linux.
#   This means structures with alignment differences need conversion logic.
#   Often structures are nested, so the parent structure may not need any
#   conversion, but some child may need some.
#
# xr.xml parsing challenges:
# - Contains type data for all platforms (generic Vulkan, Windows, Linux,..).
#   Parsing of extension information required to pull in types and functions
#   we really want to generate. Just tying all the data together is tricky.
#
# - Extensions can affect core types e.g. add new enum values, bitflags or
#   additional structure chaining through 'next' / 'type'.
#
# - Arrays are used all over the place for parameters or for structure members.
#   Array length is often stored in a previous parameter or another structure
#   member and thus needs careful parsing.

# Logger for generator diagnostics. It is instantiated directly rather than via
# logging.getLogger(), so it is not registered in the logging module's logger
# hierarchy; its output goes through the StreamHandler attached below (stderr).
LOGGER = logging.Logger("openxr")
LOGGER.addHandler(logging.StreamHandler())

# Version of the xr.xml registry this generator targets.
# NOTE(review): presumably used elsewhere to locate/fetch that registry — confirm.
XR_XML_VERSION = "1.1.51"
# NOTE(review): looks like the (major, minor) OpenXR API version wineopenxr
# reports — confirm against where it is read.
WINE_XR_VERSION = (1, 1)

# Filenames to create.
WINE_OPENXR_H = "./wineopenxr.h"
WINE_OPENXR_JSON = "./wineopenxr.json"
WINE_OPENXR_THUNKS_C = "./openxr_thunks.c"
WINE_OPENXR_THUNKS_H = "./openxr_thunks.h"
WINE_OPENXR_LOADER_THUNKS_C = "./loader_thunks.c"
WINE_OPENXR_LOADER_THUNKS_H = "./loader_thunks.h"

# Extension enum values start at a certain offset (EXT_BASE).
# Relative to the offset each extension has a block (EXT_BLOCK_SIZE)
# of values.
# Start for a given extension is:
# EXT_BASE + (extension_number-1) * EXT_BLOCK_SIZE
EXT_BASE = 1000000000
EXT_BLOCK_SIZE = 1000

UNSUPPORTED_EXTENSIONS = [
    # Handling of XR_EXT_debug_utils requires some consideration. The win32
    # loader already provides it for us and it is somewhat usable. If we add
    # plumbing down to the native layer, we will get each message twice as we
    # use 2 loaders (win32+native), but we may get output from the driver.
    # In any case callback conversion is required.
    # NOTE(review): this note was inherited from a comment about
    # XR_EXT_debug_report; confirm it still describes XR_EXT_debug_utils.
    "XR_EXT_debug_utils",
    "XR_KHR_loader_init",
    "XR_MSFT_perception_anchor_interop",
    "XR_HTC_foveation",
]

# Functions part of our wineopenxr graphics driver interface.
# DRIVER_VERSION should be bumped on any change to driver interface
# in FUNCTION_OVERRIDES
DRIVER_VERSION = 1

# Table of functions for which we have a special implementation.
# These are regular device / instance functions for which we need
# to do more work compared to a regular thunk or because they are
# part of the driver interface.
# Every function listed here is marked as required by XrFunction.
# Recognized keys:
# - dispatch (default: True):  set whether we need a function pointer in the device / instance dispatch table.
# - extra_param (default: None): name of an extra 'void *' argument that XrFunction
#   appends to the thunk prototype and forwards to the implementation.
FUNCTION_OVERRIDES = {
    # Global functions
    "xrCreateInstance" : {"dispatch" : False},

    "xrCreateSession" : {"dispatch" : True},

    "xrGetInstanceProcAddr" : {"dispatch" : False},
    "xrEnumerateInstanceExtensionProperties" : {"dispatch" : False},

    "xrConvertTimeToWin32PerformanceCounterKHR" : {"dispatch" : False},
    "xrConvertWin32PerformanceCounterToTimeKHR" : {"dispatch" : False},
    "xrGetD3D11GraphicsRequirementsKHR" : {"dispatch" : False},
    "xrGetD3D12GraphicsRequirementsKHR" : {"dispatch" : False},

    "xrGetVulkanGraphicsDeviceKHR" : {"dispatch" : True},
    "xrGetVulkanGraphicsDevice2KHR" : {"dispatch" : True},
    "xrGetVulkanDeviceExtensionsKHR" : {"dispatch" : True},
    "xrGetVulkanInstanceExtensionsKHR" : {"dispatch" : True},

    "xrCreateSwapchain" : {"dispatch" : True},
}

# functions for which a user driver entry must be generated
# (XrFunction gives these Unwrap.DRIVER and reports them via is_driver_func()).
USER_DRIVER_FUNCS = {}

# functions for which the unix thunk is manually implemented
# (XrFunction gives these Unwrap.NONE and the generated thunk body calls the
# hand-written wine_<name>() instead of the host dispatch table).
MANUAL_UNIX_THUNKS = {
    "xrCreateInstance",
    "xrCreateSession",
    "xrCreateSwapchain",
    "xrGetInstanceProcAddr",
    "xrEnumerateInstanceExtensionProperties",
    "xrConvertTimeToWin32PerformanceCounterKHR",
    "xrConvertWin32PerformanceCounterToTimeKHR",
    "xrGetD3D11GraphicsRequirementsKHR",
    "xrGetD3D12GraphicsRequirementsKHR",
    "xrGetVulkanGraphicsDeviceKHR",
    "xrGetVulkanGraphicsDevice2KHR",
    "xrGetVulkanInstanceExtensionsKHR",
}

# loader functions which are entirely manually implemented
# (excluded from XrFunction.needs_private_thunk()).
MANUAL_LOADER_FUNCTIONS = {
    "xrConvertTimeToWin32PerformanceCounterKHR",
    "xrConvertWin32PerformanceCounterToTimeKHR",
    "xrGetD3D11GraphicsRequirementsKHR",
    "xrGetD3D12GraphicsRequirementsKHR",
    "xrCreateApiLayerInstance",
    "xrGetInstanceProcAddr",
    "xrNegotiateLoaderRuntimeInterface",
    "xrNegotiateLoaderApiLayerInterface",
    "xrCreateVulkanInstanceKHR",
    "xrCreateVulkanDeviceKHR"
}

# functions which loader thunks are manually implemented
# NOTE(review): not referenced in the portion of the generator reviewed here;
# presumably consumed by the loader-thunk writer — confirm.
MANUAL_LOADER_THUNKS = {
    "xrCreateInstance",
    "xrDestroyInstance",
    "xrCreateSession",
    "xrDestroySession",
    "xrPollEvent",
    "xrGetSystem",
    "xrEnumerateSwapchainFormats",
    "xrCreateSwapchain",
    "xrDestroySwapchain",
    "xrEnumerateSwapchainImages",
    "xrAcquireSwapchainImage",
    "xrReleaseSwapchainImage",
    "xrBeginFrame",
    "xrEndFrame",
    "xrGetVulkanDeviceExtensionsKHR",
}

# NOTE(review): STRUCT_CHAIN_CONVERSIONS and STRUCT_COPY are consumed outside the
# portion of the generator reviewed here; confirm their exact semantics there.
STRUCT_CHAIN_CONVERSIONS = {
    "XrInstanceCreateInfo": [],
}

# A required function is hidden only when every extension it belongs to is
# listed here (see XrFunction.needs_exposing).
UNEXPOSED_EXTENSIONS = {}
STRUCT_COPY = {}

# Some struct members are conditionally ignored and callers are free to leave them uninitialized.
# We can't deduce that from XML, so we allow expressing it here.
MEMBER_LENGTH_EXPRESSIONS = {}

# Functions treated as performance critical (see XrFunction.is_perf_critical):
# their generated thunks skip the TRACE and return void instead of NTSTATUS.
PERF_CRITICAL_FUNCTIONS = []

# NOTE(review): presumably the xr.xml "protect" macros whose guarded types and
# functions are allowed into the generated code — confirm against where it is read.
ALLOWED_PROTECTS = [
    "XR_USE_PLATFORM_WIN32",
    "XR_USE_GRAPHICS_API_VULKAN",
    "XR_USE_GRAPHICS_API_OPENGL",
    "XR_USE_GRAPHICS_API_D3D11",
    "XR_USE_GRAPHICS_API_D3D12",
]
NOT_OUR_FUNCTIONS = [
    # xr.xml defines that as a part of XR_LOADER_VERSION_1_0 commands but it looks like only layers should provide it
    # (through dll export).
    "xrNegotiateLoaderApiLayerInterface",
]

class Direction(Enum):
    """ Direction in which a parameter or struct member is converted.

    INPUT:  win32 -> host, performed before the host function is called.
    OUTPUT: host -> win32, performed after the host function returns.
    """
    INPUT = 1
    OUTPUT = 2


class Unwrap(Enum):
    """ Which handle-unwrapping variant a function's conversions use.

    XrFunction picks NONE for MANUAL_UNIX_THUNKS, DRIVER for USER_DRIVER_FUNCS
    and HOST for everything else; convert_suffix() maps the value to the name
    of the conversion helper.
    """
    # Conversions keep wrapped handles ("unwrapped_host" helpers).
    NONE = 0
    # Conversions unwrap to host handles ("host" helpers).
    HOST = 1
    # Conversions unwrap to driver handles ("driver" helpers).
    DRIVER = 2


def api_is_openxr(obj):
    """ Return whether an xr.xml element applies to the openxr API.

    Elements without an 'api' attribute are treated as openxr; otherwise the
    attribute is a comma separated list that must contain "openxr" exactly.
    """
    api_list = obj.get("api", "openxr")
    return "openxr" in api_list.split(",")


def convert_suffix(direction, win_type, unwrap, is_wrapped):
    """ Build the suffix of a conversion helper name.

    Args:
        direction (Direction): OUTPUT yields "<host side>_to_<win_type>",
            anything else yields "<win_type>_to_<host side>".
        win_type (str): win32-side name, e.g. "win32" or "win64".
        unwrap (Unwrap): handle unwrapping mode; only relevant when is_wrapped.
        is_wrapped (bool): whether the converted type contains wrapped handles.
    """
    # Work out the name of the host side of the conversion first.
    if is_wrapped and unwrap == Unwrap.NONE:
        host_side = "unwrapped_host"
    elif is_wrapped and unwrap == Unwrap.DRIVER:
        host_side = "driver"
    else:
        host_side = "host"

    if direction == Direction.OUTPUT:
        return "{0}_to_{1}".format(host_side, win_type)
    return "{0}_to_{1}".format(win_type, host_side)


class XrBaseType(object):
    """ A basic xr.xml type emitted either as a typedef or as an opaque struct. """

    def __init__(self, name, _type, alias=None, requires=None):
        """ Create a base type.

        Args:
            name (:obj:'str'): Name of the base type.
            _type (:obj:'str'): Underlying type, or None for an opaque struct.
            alias: truthy when the type is an alias (see is_alias()).
            requires (:obj:'str', optional): Other types required.
                Often bitmask values pull in a *FlagBits type.
        """
        self.name, self.type = name, _type
        self.alias, self.requires = alias, requires
        self.required = False

    def definition(self):
        """ Return the C declaration for this type. """
        # Aliases already carry the aliased type in self.type, so they take
        # the same path as regular typedefs.
        if self.type is None:
            return "struct {0};\n".format(self.name)
        return "typedef {0} {1};\n".format(self.type, self.name)

    def is_alias(self):
        return True if self.alias else False


class XrConstant(object):
    """ A named xr.xml constant emitted as a C preprocessor #define. """

    def __init__(self, name, value):
        self.name, self.value = name, value

    def definition(self):
        """ Return the '#define NAME VALUE' line for this constant. """
        return "#define {0} {1}\n".format(self.name, self.value)


class XrDefine(object):
    """ A <type category="define"> entry from xr.xml, kept as raw C text. """

    def __init__(self, name, value):
        self.name = name
        # C text emitted by definition(); None means nothing is emitted.
        self.value = value

    @staticmethod
    def from_xml(define):
        """ Parse a <type category="define"> element.

        Returns:
            XrDefine, or None when the element does not apply to the openxr API.
        """
        if not api_is_openxr(define):
            return None

        name_elem = define.find("name")

        if name_elem is None:
            # <type category="define" name="some_name">some_value</type>
            name = define.attrib.get("name")

            # We override behavior of XR_USE_64_BIT_PTR_DEFINES as the default non-dispatchable handle
            # definition varies between 64-bit (uses pointers) and 32-bit (uses uint64_t).
            # This complicates TRACEs in the thunks, so just use uint64_t.
            if name == "XR_USE_64_BIT_PTR_DEFINES":
                value = "#define XR_USE_64_BIT_PTR_DEFINES 0"
            else:
                value = define.text
            return XrDefine(name, value)

        # With a name element the structure is like:
        # <type category="define"><name>some_name</name>some_value</type>
        name = name_elem.text

        # Perform minimal parsing for constants, which we don't need, but are referenced
        # elsewhere in xr.xml.
        # - XR_API_VERSION is a messy, deprecated constant and we don't want generate code for it.
        # - AHardwareBuffer/ANativeWindow are forward declarations for Android types, which leaked
        #   into the define region.
        if name in ["XR_API_VERSION", "AHardwareBuffer", "ANativeWindow", "CAMetalLayer"]:
            return XrDefine(name, None)

        # The body of the define is basically unstructured C code. It is not meant for easy parsing.
        # Some lines contain deprecated values or comments, which we try to filter out.
        # ElementTree reports absent text as None; treat it as empty instead of crashing.
        value = ""
        for line in (define.text or "").splitlines():
            value += "\n"
            # Skip comments or deprecated values.
            if "//" in line:
                continue
            value += line

        for child in define:
            value += child.text or ""
            if child.tail is not None:
                # Split comments for XR_API_VERSION_1_0 / XR_API_VERSION_1_1
                if "//" in child.tail:
                    value += child.tail.split("//")[0]
                else:
                    value += child.tail

        return XrDefine(name, value.rstrip(' '))

    def definition(self):
        """ Return the C text for this define, or "" when it has no value. """
        if self.value is None:
            return ""

        # Nothing to do as the value was already put in the right form during parsing.
        return "{0}\n".format(self.value)

    def is_alias(self):
        return False

class XrEnum(object):
    """ An enumeration (32-bit) or flag-bits (64-bit) type parsed from xr.xml. """

    def __init__(self, name, bitwidth, alias=None):
        """ Create an enum.

        Args:
            name (str): C type name.
            bitwidth (int): 32 (emitted as a C enum) or 64 (emitted as XrFlags64
                constants). Any other width is logged as an error.
            alias (XrEnum, optional): enum this one aliases. An alias shares
                (does not copy) the aliased enum's value list.
        """
        if bitwidth not in [32, 64]:
            LOGGER.error("unknown bitwidth {0} for {1}".format(bitwidth, name))
        self.name = name
        self.bitwidth = bitwidth
        self.values = [] if alias is None else alias.values
        self.required = False
        self.alias = alias
        # Enums declared as aliases of this one; definition() typedefs them.
        self.aliased_by = []

    @staticmethod
    def from_alias(enum, alias):
        """ Create the enum named by xml element 'enum' as an alias of XrEnum 'alias'. """
        name = enum.attrib.get("name")
        aliasee = XrEnum(name, alias.bitwidth, alias=alias)

        alias.add_aliased_by(aliasee)
        return aliasee

    @staticmethod
    def from_xml(enum):
        """ Parse an <enums> element.

        Returns:
            XrEnum, or None when the element does not apply to the openxr API.
        """
        if not api_is_openxr(enum):
            return None

        name = enum.attrib.get("name")
        bitwidth = int(enum.attrib.get("bitwidth", "32"))
        result = XrEnum(name, bitwidth)

        for v in enum.findall("enum"):
            value_name = v.attrib.get("name")
            # An entry is an alias, an explicit value, or (otherwise) a bit position.
            value = v.attrib.get("value")
            alias_name = v.attrib.get("alias")
            if alias_name:
                result.create_alias(value_name, alias_name)
            elif value:
                result.create_value(value_name, value)
            else:
                # bitmask
                result.create_bitpos(value_name, int(v.attrib.get("bitpos")))

        if bitwidth == 32:
            # openxr.h contains a *_MAX_ENUM value set to 32-bit at the time of writing,
            # which is to prepare for extensions as they can add values and hence affect
            # the size definition.
            max_name = re.sub(r'([0-9a-z_])([A-Z0-9])',r'\1_\2', name).upper() + "_MAX_ENUM"
            result.create_value(max_name, "0x7fffffff")

        return result

    def create_alias(self, name, alias_name):
        """ Create an aliased value for this enum """
        self.add(XrEnumValue(name, self.bitwidth, alias=alias_name))

    def create_value(self, name, value):
        """ Create a new value for this enum from its xml string form. """
        # Some values are in hex form. We want to preserve the hex representation
        # at least when we convert back to a string. Internally we want to use int.
        # (Named is_hex so the builtin hex() is not shadowed.)
        is_hex = "0x" in value
        self.add(XrEnumValue(name, self.bitwidth, value=int(value, 0), hex=is_hex))

    def create_bitpos(self, name, pos):
        """ Create a new bitmask value for this enum """
        self.add(XrEnumValue(name, self.bitwidth, value=(1 << pos), hex=True))

    def add(self, value):
        """ Add a value to enum, skipping duplicates.

        Extensions can add new enum values. When an extension is promoted to Core
        the registry defines the value twice once for old extension and once for
        new Core features. Add the duplicate if it's explicitly marked as an
        alias, otherwise ignore it.
        """
        for v in self.values:
            if not value.is_alias() and v.value == value.value:
                LOGGER.debug("Ignoring duplicate of enum value {0} in {1}".format(v, self.name))
                return
        # Avoid adding duplicate aliases multiple times
        if not any(x.name == value.name for x in self.values):
            self.values.append(value)

    def fixup_64bit_aliases(self):
        """ Replace 64bit aliases with literal values """
        # Older GCC versions need a literal to initialize a static const uint64_t
        # which is what we use for 64bit bitmasks.
        # NOTE(review): next() raises StopIteration if an alias names a value
        # that is not in this enum.
        if self.bitwidth != 64:
            return
        for value in self.values:
            if not value.is_alias():
                continue
            alias = next(x for x in self.values if x.name == value.alias)
            value.hex = alias.hex
            value.value = alias.value

    def definition(self):
        """ Return the C definition of this enum plus typedefs for its aliases.

        Aliases themselves produce "" since the aliased enum emits their typedef.
        """
        if self.is_alias():
            return ""

        # Values without a literal (unresolved aliases) sort just below *_MAX_ENUM.
        default_value = 0x7ffffffe if self.bitwidth == 32 else 0xfffffffffffffffe

        # Print values sorted, values can have been added in a random order.
        values = sorted(self.values, key=lambda value: value.value if value.value is not None else default_value)

        if self.bitwidth == 32:
            text = "typedef enum {0}\n{{\n".format(self.name)
            for value in values:
                text += "    {0},\n".format(value.definition())
            text += "}} {0};\n".format(self.name)
        elif self.bitwidth == 64:
            text = "typedef XrFlags64 {0};\n\n".format(self.name)
            for value in values:
                text += "static const {0} {1};\n".format(self.name, value.definition())

        for aliasee in self.aliased_by:
            text += "typedef {0} {1};\n".format(self.name, aliasee.name)

        text += "\n"
        return text

    def is_alias(self):
        return bool(self.alias)

    def add_aliased_by(self, aliasee):
        self.aliased_by.append(aliasee)


class XrEnumValue(object):
    """ A single value of an XrEnum: either a literal or an alias of another value. """

    def __init__(self, name, bitwidth, value=None, hex=False, alias=None):
        self.name = name
        self.bitwidth = bitwidth
        # Integer value; stays None for aliases until resolved.
        self.value = value
        # Print the value in hexadecimal form.
        self.hex = hex
        # Name of the aliased value, or None.
        self.alias = alias

    def __repr__(self):
        suffix = "ull" if self.bitwidth == 64 else ""
        if self.value is None and self.is_alias():
            return "{0}={1}".format(self.name, self.alias)
        return "{0}={1}{2}".format(self.name, self.value, suffix)

    def definition(self):
        """ Render as C text, e.g. XR_FOO = 1 or XR_BAR = XR_FOO. """
        suffix = "ull" if self.bitwidth == 64 else ""
        if self.value is None and self.is_alias():
            return "{0} = {1}".format(self.name, self.alias)

        # Hex is commonly used for FlagBits and sometimes within a
        # non-FlagBits enum for a bitmask value as well.
        template = "{0} = 0x{1:08x}{2}" if self.hex else "{0} = {1}{2}"
        return template.format(self.name, self.value, suffix)

    def is_alias(self):
        return self.alias is not None


class XrFunction(object):
    """ An xr.xml command plus the metadata needed to generate its thunks. """

    def __init__(self, _type=None, name=None, params=None, alias=None):
        """ Create a function description.

        Args:
            _type (str): C return type.
            name (str): function name, e.g. xrCreateSession.
            params (list, optional): XrParam objects; a fresh empty list when omitted.
            alias (XrFunction, optional): function this one aliases.
        """
        self.extensions = set()
        self.name = name
        self.type = _type
        # Avoid one mutable default list being shared by all instances.
        self.params = params if params is not None else []
        self.alias = alias

        # For some functions we need some extra metadata from FUNCTION_OVERRIDES.
        func_info = FUNCTION_OVERRIDES.get(self.name, {})
        self.dispatch = func_info.get("dispatch", True)
        self.extra_param = func_info.get("extra_param", None)

        # Required is set while parsing which APIs and types are required
        # and is used by the code generation. Overridden functions start as required.
        self.required = bool(func_info)

        # Select which handle-unwrapping conversions this function's thunk uses.
        if self.name in MANUAL_UNIX_THUNKS:
            self.unwrap = Unwrap.NONE
        elif self.name in USER_DRIVER_FUNCS:
            self.unwrap = Unwrap.DRIVER
        else:
            self.unwrap = Unwrap.HOST

    @staticmethod
    def from_alias(command, alias):
        """ Create XrFunction from an alias command.

        Args:
            command: xml data for command
            alias (XrFunction): function to use as a base for types / parameters.

        Returns:
            XrFunction, or None when the command does not apply to the openxr API.
        """
        if not api_is_openxr(command):
            return None

        func_name = command.attrib.get("name")
        func_type = alias.type
        params = alias.params

        return XrFunction(_type=func_type, name=func_name, params=params, alias=alias)

    @staticmethod
    def from_xml(command, types):
        """ Parse a <command> element.

        Returns:
            XrFunction, or None when the command does not apply to the openxr API.
        """
        if not api_is_openxr(command):
            return None

        proto = command.find("proto")
        func_name = proto.find("name").text
        func_type = proto.find("type").text

        params = []
        for param in command.findall("param"):
            xr_param = XrParam.from_xml(param, types, params)
            if xr_param:
                params.append(xr_param)

        return XrFunction(_type=func_type, name=func_name, params=params)

    def get_conversions(self):
        """ Get a list of conversion functions required for this function if any.
        Parameters which are structures may require conversion between win32
        and the host platform. This function returns a list of conversions
        required.
        """

        conversions = []
        for param in self.params:
            conversions.extend(param.get_conversions(self.unwrap))
        return conversions

    def is_alias(self):
        return bool(self.alias)

    def is_core_func(self):
        """ Returns whether the function is a core function.
        A function belonging to no extension is core; otherwise it is core when
        any of its extensions is listed in CORE_EXTENSIONS.
        """
        # NOTE(review): CORE_EXTENSIONS is not defined in the part of the file
        # reviewed here; confirm it exists, or this raises NameError for
        # extension functions.
        if not self.extensions:
            return True

        return any(ext in self.extensions for ext in CORE_EXTENSIONS)

    def is_device_func(self):
        # If none of the other, it must be a device function
        return not self.is_global_func() and not self.is_instance_func() and not self.is_phys_dev_func()

    def is_driver_func(self):
        """ Returns if function is part of Wine driver interface. """
        return self.name in USER_DRIVER_FUNCS

    def is_global_func(self):
        # Treat xrGetInstanceProcAddr as a global function as it
        # can operate with NULL for xrInstance.
        if self.name == "xrGetInstanceProcAddr":
            return True
        # Global functions are not passed a dispatchable object.
        elif self.params[0].is_dispatchable():
            return False
        return True

    def is_instance_func(self):
        # Instance functions are passed XrInstance.
        return self.params[0].type == "XrInstance"

    def is_phys_dev_func(self):
        # Physical device functions are passed XrPhysicalDevice.
        return self.params[0].type == "XrPhysicalDevice"

    def is_required(self):
        return self.required

    def returns_longlong(self):
        return self.type in ["uint64_t", "XrDeviceAddress"]

    def needs_dispatch(self):
        return self.dispatch

    def needs_private_thunk(self):
        return self.needs_exposing() and self.name not in MANUAL_LOADER_FUNCTIONS and \
            self.name in MANUAL_UNIX_THUNKS

    def needs_exposing(self):
        # The function needs exposed if at-least one extension isn't both UNSUPPORTED and UNEXPOSED
        return self.is_required() and (not self.extensions or not self.extensions.issubset(UNEXPOSED_EXTENSIONS))

    def is_perf_critical(self):
        # xrCmd* functions are frequently called, do not trace for performance
        if self.name.startswith("xrCmd") and self.type == "void":
            return True
        return self.name in PERF_CRITICAL_FUNCTIONS

    def pfn(self, prefix="p", call_conv=None):
        """ Create function pointer declaration text, e.g. 'XrResult (*p_xrFoo)(...)'. """

        if call_conv:
            pfn = "{0} ({1} *{2}_{3})(".format(self.type, call_conv, prefix, self.name)
        else:
            pfn = "{0} (*{1}_{2})(".format(self.type, prefix, self.name)

        for i, param in enumerate(self.params):
            if param.const:
                pfn += param.const + " "

            pfn += param.type

            if param.is_pointer():
                pfn += " " + param.pointer

            if param.array_len is not None:
                pfn += "[{0}]".format(param.array_len)

            if i < len(self.params) - 1:
                pfn += ", "
        pfn += ")"
        return pfn

    def prototype(self, call_conv=None, prefix=None, is_thunk=False):
        """ Generate prototype for given function.

        Args:
            call_conv (str, optional): calling convention e.g. WINAPI
            prefix (str, optional): prefix to append prior to function name e.g. xrFoo -> wine_xrFoo
            is_thunk (bool, optional): append the FUNCTION_OVERRIDES extra_param
                (as 'void *', renamed win_<name> if it clashes with a parameter).
        """

        proto = "{0}".format(self.type)

        if call_conv is not None:
            proto += " {0}".format(call_conv)

        if prefix is not None:
            proto += " {0}{1}(".format(prefix, self.name)
        else:
            proto += " {0}(".format(self.name)

        # Add all the parameters.
        proto += ", ".join([p.definition() for p in self.params])

        if is_thunk and self.extra_param:
            extra_param_is_new = not any(p.name == self.extra_param for p in self.params)
            if extra_param_is_new:
                proto += ", void *" + self.extra_param
            else:
                proto += ", void *win_" + self.extra_param

        proto += ")"
        return proto

    def loader_body(self):
        """ Body of the PE-side thunk: pack arguments and issue the UNIX_CALL. """
        body = "    struct {0}_params params;\n".format(self.name)
        if not self.is_perf_critical():
            body += "    NTSTATUS _status;\n"
        for p in self.params:
            body += "    params.{0} = {0};\n".format(p.name)

        # Call the Unix function.
        if self.is_perf_critical():
            body += "    UNIX_CALL({0}, &params);\n".format(self.name)
        else:
            body += "    _status = UNIX_CALL({0}, &params);\n".format(self.name)
            body += "    assert(!_status && \"{0}\");\n".format(self.name)

        if self.type != "void":
            body += "    return params.result;\n"
        return body

    def body(self, conv, params_prefix=""):
        """ Body of the unix-side thunk: convert arguments, call, convert back.

        Args:
            conv (bool): generate the 32-bit (converting) variant.
            params_prefix (str): prefix used to access parameters, e.g. "params->".
        """
        body = ""
        needs_alloc = False
        deferred_op = None

        # Declare any tmp parameters for conversion.
        for p in self.params:
            if p.needs_variable(conv, self.unwrap):
                if p.is_dynamic_array():
                    body += "    {2}{0} *{1}_host;\n".format(
                        p.type, p.name, "const " if p.is_const() else "")
                elif p.optional:
                    body += "    {0} *{1}_host = NULL;\n".format(p.type, p.name)
                    needs_alloc = True
                else:
                    body += "    {0} {1}_host;\n".format(p.type, p.name)
            if p.needs_alloc(conv, self.unwrap):
                needs_alloc = True
            if p.type == "XrDeferredOperationKHR" and not p.is_pointer():
                deferred_op = p.name

        if needs_alloc:
            body += "    struct conversion_context local_ctx;\n"
            body += "    struct conversion_context *ctx = &local_ctx;\n"
        body += "\n"

        if not self.is_perf_critical():
            body += "    {0}\n".format(self.trace(params_prefix=params_prefix, conv=conv))

        if self.params[0].optional and self.params[0].is_handle():
            if self.type != "void":
                LOGGER.warning("return type {0} with optional handle not supported".format(self.type))
            body += "    if (!{0}{1})\n".format(params_prefix, self.params[0].name)
            body += "        return STATUS_SUCCESS;\n\n"

        # NOTE(review): the deferred-operation path hardcodes "params->" rather
        # than using params_prefix.
        if needs_alloc:
            if deferred_op is not None:
                body += "    if (params->{} == XR_NULL_HANDLE)\n".format(deferred_op)
                body += "    "
            body += "    init_conversion_context(ctx);\n"
            if deferred_op is not None:
                body += "    else\n"
                body += "        ctx = &wine_deferred_operation_from_handle(params->{})->ctx;\n".format(deferred_op)

        # Call any win_to_host conversion calls.
        for p in self.params:
            if p.needs_conversion(conv, self.unwrap, Direction.INPUT):
                body += p.copy(Direction.INPUT, conv, self.unwrap, prefix=params_prefix)
            elif p.is_dynamic_array() and p.needs_conversion(conv, self.unwrap, Direction.OUTPUT):
                body += "    {0}_host = ({2}{0} && {1}) ? conversion_context_alloc(ctx, sizeof(*{0}_host) * {1}) : NULL;\n".format(
                    p.name, p.get_dyn_array_len(params_prefix, conv), params_prefix)

        # Build list of parameters containing converted and non-converted parameters.
        # The param itself knows if conversion is needed and applies it when we set conv=True.
        unwrap = Unwrap.NONE if self.name in MANUAL_UNIX_THUNKS else self.unwrap
        params = ", ".join([p.variable(conv, unwrap, params_prefix) for p in self.params])
        if self.extra_param:
            if conv:
                params += ", UlongToPtr({0}{1})".format(params_prefix, self.extra_param)
            else:
                params += ", (void *){0}{1}".format(params_prefix, self.extra_param)

        # Manually implemented thunks call wine_<name>(); others go through the host dispatch table.
        if self.name not in MANUAL_UNIX_THUNKS:
            func_prefix = "g_xr_host_instance_dispatch_table.p_"
        else:
            func_prefix = "wine_"

        # Call the host function.
        if self.type == "void":
            body += "    {0}{1}({2});\n".format(func_prefix, self.name, params)
        else:
            body += "    {0}result = {1}{2}({3});\n".format(params_prefix, func_prefix, self.name, params)

        # Call any host_to_win conversion calls.
        for p in self.params:
            if p.needs_conversion(conv, self.unwrap, Direction.OUTPUT):
                body += p.copy(Direction.OUTPUT, conv, self.unwrap, prefix=params_prefix)

        if needs_alloc:
            if deferred_op is not None:
                body += "    if (params->{} == XR_NULL_HANDLE)\n".format(deferred_op)
                body += "    "
            body += "    free_conversion_context(ctx);\n"

        # Finally return the result. Performance critical functions return void to allow tail calls.
        if not self.is_perf_critical():
            body += "    return STATUS_SUCCESS;\n"

        return body

    def spec(self, prefix=None, symbol=None):
        """ Generate spec file entry for this function.

        Args
            prefix (str, optional): prefix to prepend to entry point name.
            symbol (str, optional): allows overriding the name of the function implementing the entry point.
        """

        spec = ""
        params = " ".join([p.spec() for p in self.params])
        if prefix is not None:
            spec += "@ stdcall -private {0}{1}({2})".format(prefix, self.name, params)
        else:
            spec += "@ stdcall {0}({1})".format(self.name, params)

        if symbol is not None:
            spec += " " + symbol

        spec += "\n"
        return spec

    def stub(self, call_conv=None, prefix=None):
        """ Generate a FIXME stub returning a failure / empty value. """
        stub = self.prototype(call_conv=call_conv, prefix=prefix)
        stub += "\n{\n"
        stub += "    {0}".format(self.trace(message="stub: ", trace_func="FIXME"))

        if self.type == "XrResult":
            stub += "    return XR_ERROR_OUT_OF_HOST_MEMORY;\n"
        elif self.type == "XrBool32":
            stub += "    return XR_FALSE;\n"
        elif self.type == "PFN_xrVoidFunction":
            stub += "    return NULL;\n"

        stub += "}\n\n"
        return stub

    def thunk(self, prefix=None, conv=False):
        """ Generate the unix-side thunk function.

        Args:
            prefix (str, optional): prefix for the thunk name; None means no prefix.
            conv (bool): generate the 32-bit variant with an inline PTR32 params struct;
                otherwise the 64-bit variant guarded by _WIN64.
        """
        # A None prefix would otherwise be formatted as the literal text "None".
        if prefix is None:
            prefix = ""
        thunk = ""
        if not conv:
            thunk += "#ifdef _WIN64\n"
        if self.is_perf_critical():
            thunk += "static void {0}{1}(void *args)\n".format(prefix, self.name)
        else:
            thunk += "static NTSTATUS {0}{1}(void *args)\n".format(prefix, self.name)
        thunk += "{\n"
        if conv:
            thunk += "    struct\n"
            thunk += "    {\n"
            extra_param_is_new = True
            for p in self.params:
                thunk += "        {0};\n".format(p.definition(conv=True, is_member=True))
                if p.name == self.extra_param:
                    extra_param_is_new = False
            if self.extra_param and extra_param_is_new:
                thunk += "        PTR32 {0};\n".format(self.extra_param)
            if self.type != "void":
                thunk += "        {0} result;\n".format(self.type)
            thunk += "    } *params = args;\n"
        else:
            thunk += "    struct {0}_params *params = args;\n".format(self.name)
        thunk += self.body(conv, params_prefix="params->")
        thunk += "}\n"
        if not conv:
            thunk += "#endif /* _WIN64 */\n"
        thunk += "\n"
        return thunk

    def loader_thunk(self, prefix=None):
        """ Generate the PE-side WINAPI entry point that forwards to the unix side. """
        thunk = self.prototype(call_conv="WINAPI", prefix=prefix)
        thunk += "\n{\n"
        thunk += self.loader_body()
        thunk += "}\n\n"
        return thunk

    def trace(self, message=None, trace_func=None, params_prefix="", conv=False):
        """ Create a trace string including all parameters.

        Args:
            message (str, optional): text to print at start of trace message e.g. 'stub: '
            trace_func (str, optional): used to override trace function e.g. FIXME, printf, etcetera.
        """
        if trace_func is not None:
            trace = "{0}(\"".format(trace_func)
        else:
            trace = "TRACE(\""

        if message is not None:
            trace += message

        # First loop is for all the format strings.
        trace += ", ".join([p.format_string(conv) for p in self.params])
        trace += "\\n\""

        # Second loop for parameter names and optional conversions.
        for param in self.params:
            if param.format_conv is not None:
                trace += ", " + param.format_conv.format("{0}{1}".format(params_prefix, param.name))
            else:
                trace += ", {0}{1}".format(params_prefix, param.name)
        trace += ");\n"

        return trace

class XrFunctionPointer(object):
    """ A funcpointer type from the registry, emitted as a C function pointer typedef. """

    def __init__(self, _type, name, members, forward_decls, params_text):
        self.name = name
        self.members = members
        self.type = _type
        self.required = False
        self.forward_decls = forward_decls
        self.params_text = params_text

    @staticmethod
    def from_xml(funcpointer):
        """ Parse a <type category="funcpointer"> element.

        Parameters are taken from the <type> children when there are any. Otherwise
        the raw text between the parentheses of the parameter list is kept in
        params_text and emitted verbatim by definition().
        """
        members = []
        name_elem = funcpointer.find("name")

        # Text preceding the current <type> tag. For the first parameter this is the
        # tail of <name> (e.g. ")(\n    const "), so a leading 'const' is not lost.
        begin = name_elem.tail

        for t in funcpointer.findall("type"):
            # General form:
            # <type>void</type>*       pUserData,
            # Parsing of the tail (anything past </type>) is tricky since there
            # can be other data on the next line like: const <type>int</type>..

            const = bool(begin and "const" in begin)
            _type = t.text
            lines = (t.tail or "").split(",\n")
            if lines[0].startswith("*"):
                pointer = "*"
                name = lines[0][1:].strip()
            else:
                pointer = None
                name = lines[0].strip()

            # Filter out ); if it is contained.
            name = name.partition(");")[0]

            # If the tail spans multiple lines, its second line is the text that
            # precedes the next <type> tag (e.g. "const").
            begin = lines[1].strip() if len(lines) > 1 else None

            members.append(XrMember(const=const, _type=_type, pointer=pointer, name=name))

        _type = funcpointer.text
        name = name_elem.text

        requires = funcpointer.attrib.get("requires")
        if requires:
            forward_decls = [decl.strip() for decl in requires.split(",") if decl.strip()]
        else:
            forward_decls = []

        params_text = None
        if not members:
            # Without <type> children the parameter list is plain text. The third
            # non-empty text chunk is normally the tail of <name>, e.g. ")(const X* x);".
            text_parts = list(funcpointer.itertext())
            if len(text_parts) >= 3:
                match = re.search(r'\([^)]*\)', text_parts[2])
                if match:
                    params_text = match.group(0)[1:-1]

        # NOTE(review): presumably needed because the plain-text parameter list refers
        # to XrInstanceCreateInfo without 'requires' listing it; avoid a duplicate typedef.
        if name == "PFN_xrCreateApiLayerInstance" and "XrInstanceCreateInfo" not in forward_decls:
            forward_decls.append("XrInstanceCreateInfo")

        return XrFunctionPointer(_type, name, members, forward_decls, params_text)

    def definition(self):
        """ Return the C typedef for this function pointer.

        Each required struct is forward declared first. The parameter list comes from
        the parsed members, the raw params_text, or 'void' if neither is available.
        """
        text = "".join("typedef struct {0} {0};\n".format(decl) for decl in self.forward_decls)

        text += "{0} {1})(\n".format(self.type, self.name)

        if self.members:
            text += ",\n".join("    " + m.definition() for m in self.members)
        elif self.params_text is not None:
            text += self.params_text
        else:
            # Just make the compiler happy by adding a void parameter.
            text += "void"
        text += ");\n"
        return text

    def is_alias(self):
        return False

class XrHandle(object):
    """ An OpenXR handle type parsed from the registry (e.g. XrInstance, XrSession). """

    # str.format templates giving access to the host handle of each wrapped handle type.
    _HOST_HANDLE_ACCESSORS = {
        "XrInstance": "wine_instance_from_handle({0})->host_instance",
        "XrSession": "wine_session_from_handle({0})->host_session",
        "XrSwapchain": "wine_swapchain_from_handle({0})->host_swapchain",
    }

    def __init__(self, name, _type, parent, alias=None):
        self.name = name
        self.type = _type
        self.parent = parent
        self.alias = alias
        self.required = False
        self.object_type = None

    @staticmethod
    def from_alias(handle, alias):
        """ Create a handle named after the XML element that reuses alias's type and parent. """
        return XrHandle(handle.attrib.get("name"), alias.type, alias.parent, alias=alias)

    @staticmethod
    def from_xml(handle):
        """ Parse a handle <type> element; returns None if api_is_openxr() rejects it. """
        if not api_is_openxr(handle):
            return None

        # The 'parent' attribute is optional; when absent the parent is None.
        return XrHandle(
            handle.find("name").text,
            handle.find("type").text,
            handle.attrib.get("parent"),
        )

    def definition(self):
        """ Return the C definition, e.g. XR_DEFINE_HANDLE(XrInstance).

        Alias handles are emitted as a typedef of the handle they alias.
        """
        if not self.is_alias():
            return "{0}({1})\n".format(self.type, self.name)
        return "typedef {0} {1};\n".format(self.alias.name, self.name)

    def is_alias(self):
        return self.alias is not None

    def is_dispatchable(self):
        """ True for handles declared with XR_DEFINE_HANDLE (dispatchable objects). """
        return self.type == "XR_DEFINE_HANDLE"

    def is_required(self):
        return self.required

    def host_handle(self, name):
        """ Return a C expression reaching the host handle behind the wrapped
        handle expression `name`, or None if this handle type is not wrapped.
        """
        accessor = self._HOST_HANDLE_ACCESSORS.get(self.name)
        if accessor is None:
            return None
        return accessor.format(name)

    def driver_handle(self, name):
        """ Return the expression for the handle passed to the wine driver (the host handle). """
        return self.host_handle(name)

    def unwrap_handle(self, name, unwrap):
        """ Return `name` unwrapped according to the Unwrap mode, or None for an unknown mode. """
        if unwrap == Unwrap.NONE:
            return name
        if unwrap == Unwrap.HOST:
            return self.host_handle(name)
        if unwrap == Unwrap.DRIVER:
            return self.driver_handle(name)
        return None

    def is_wrapped(self):
        """ True if this handle type has a host-handle accessor, i.e. is wrapped by wine. """
        return self.host_handle("test") is not None


class XrVariable(object):
    """ Common base of XrMember (struct/union members) and XrParam (function parameters).

    Holds the parsed declaration (type, name, pointer, array lengths, ...) and answers
    the questions the conversion and thunk generators ask about it.
    """

    def __init__(self, const=False, type_info=None, type=None, name=None, pointer=None, array_len=None,
                 dyn_array_len=None, object_type=None, optional=False, returnedonly=False, parent=None,
                 selection=None, selector=None):
        self.const = const
        self.type_info = type_info
        self.type = type
        self.name = name
        self.parent = parent
        self.object_type = object_type
        self.optional = optional
        self.returnedonly = returnedonly
        self.selection = selection
        self.selector = selector

        self.pointer = pointer
        self.array_len = array_len
        self.dyn_array_len = dyn_array_len
        self.pointer_array = False
        # A comma-separated length (e.g. "count,null-terminated") describes an array
        # of pointers; only the first component is the element count.
        if isinstance(dyn_array_len, str):
            outer_len, sep, _ = dyn_array_len.partition(",")
            if sep:
                self.dyn_array_len = outer_len
                self.pointer_array = True

        # Filled in by set_type_info(); default to None so is_handle() and friends
        # work before type information is available.
        self.handle = None
        self.struct = None
        if type_info:
            self.set_type_info(type_info)

    def __eq__(self, other):
        """ Compare member based on name against a string.

        This lets `"name" in some_list` and `some_list.index("name")` find variables
        by name. Since __hash__ is not defined, instances are unhashable.
        """
        return self.name == other

    def set_type_info(self, type_info):
        """ Helper function to set type information from the type registry.
        This is needed, because not all type data is available at time of
        parsing.
        """
        self.type_info = type_info
        self.handle = type_info["data"] if type_info["category"] == "handle" else None
        self.struct = type_info["data"] if type_info["category"] in ("struct", "union") else None

    def get_dyn_array_len(self, prefix, conv):
        """ Return the element count of this dynamic array.

        Args:
            prefix (str): text prepended to sibling member/parameter names to access them.
            conv (bool): whether sibling accesses go through the 32-bit (PTR32) casts of value().

        Returns:
            The int length when it is a constant, otherwise a C expression string.
        """
        if isinstance(self.dyn_array_len, int):
            return self.dyn_array_len

        len_str = self.dyn_array_len
        parent = self.parent
        length = prefix

        # check if length is a member of another struct (for example pAllocateInfo->commandBufferCount)
        head, sep, tail = len_str.partition("->")
        if sep:
            var = parent[parent.index(head)]
            len_str = tail
            length = "({0})->".format(var.value(length, conv))
            parent = var.struct

        if len_str in parent:
            var = parent[parent.index(len_str)]
            length = var.value(length, conv)
            if var.is_pointer():
                length = "*" + length
        else:
            length += len_str

        # Some struct members use a custom length expression instead.
        if isinstance(self.parent, XrStruct) and self.parent.name in MEMBER_LENGTH_EXPRESSIONS:
            exprs = MEMBER_LENGTH_EXPRESSIONS[self.parent.name]
            if self.name in exprs:
                length = exprs[self.name].format(struct=prefix, len=length)

        return length

    def is_const(self):
        return self.const

    def is_pointer(self):
        return self.pointer is not None

    def is_pointer_size(self):
        """ True if the value itself is pointer-sized (size_t, HWND, HINSTANCE or a dispatchable handle). """
        if self.type in ["size_t", "HWND", "HINSTANCE"]:
            return True
        if self.is_handle() and self.handle.is_dispatchable():
            return True
        return False

    def is_handle(self):
        return self.handle is not None

    def is_struct(self):
        return self.type_info["category"] == "struct"

    def is_union(self):
        return self.type_info["category"] == "union"

    def is_bitmask(self):
        return self.type_info["category"] == "bitmask"

    def is_enum(self):
        return self.type_info["category"] == "enum"

    def is_dynamic_array(self):
        """ Returns if the member is an array element.
        Used for dynamically sized arrays for which there is a 'count'
        parameter (plain pointers are treated as arrays of length 1).
        """
        return self.dyn_array_len is not None and self.array_len is None

    def is_static_array(self):
        """ Returns if the member is a fixed size array whose length is part
        of the declaration.
        """
        return self.array_len is not None

    def is_generic_handle(self):
        """ Returns True if the member is a uint64_t containing
        a handle with a separate object type
        """
        return self.object_type is not None and self.type == "uint64_t"

    def needs_alignment(self):
        """ Check if this member needs alignment for 64-bit data.
        Various structures need alignment on 64-bit variables due
        to compiler differences on 32-bit between Win32 and Linux.
        """

        if self.is_pointer():
            return False
        elif self.type == "size_t":
            return False
        elif self.type in ["uint64_t", "XrDeviceAddress", "XrDeviceSize"]:
            return True
        elif self.is_bitmask():
            return self.type_info["data"].type == "XrFlags64"
        elif self.is_enum():
            return self.type_info["data"].bitwidth == 64
        elif self.is_struct() or self.is_union():
            return self.type_info["data"].needs_alignment()
        elif self.is_handle():
            # Dispatchable handles are pointers to objects, while
            # non-dispatchable are uint64_t and hence need alignment.
            return not self.handle.is_dispatchable()
        return False

    def is_wrapped(self):
        """ Returns if variable needs unwrapping of handle. """

        if self.is_struct():
            return self.struct.is_wrapped()

        if self.is_handle():
            return self.handle.is_wrapped()

        if self.is_generic_handle():
            return True

        return False

    def needs_alloc(self, conv, unwrap):
        """ Returns True if conversion needs allocation.

        Relies on needs_conversion(), which is provided by the subclasses.
        """
        if self.is_dynamic_array():
            return self.needs_conversion(conv, unwrap, Direction.INPUT, False) \
                or self.needs_conversion(conv, unwrap, Direction.OUTPUT, False)

        return (self.is_struct() or (self.is_union() and self.selector)) and self.struct.needs_alloc(conv, unwrap)

    def needs_win32_type(self):
        """ True if the struct/union type of this variable has a separate win32 ('32') variant. """
        return (self.is_struct() or (self.is_union() and self.selector)) and self.struct.needs_win32_type()

    def get_conversions(self, unwrap, parent_const=False):
        """ Get a list of conversions required for this parameter if any.
        Parameters which are structures may require conversion between win32
        and the host platform. This function returns a list of conversions
        required.
        """

        conversions = []

        # Collect any member conversions first, so we can guarantee
        # those functions will be defined prior to usage by the
        # 'parent' param requiring conversion.
        if self.is_struct() or (self.is_union() and self.selector):
            struct = self.struct
            is_const = self.is_const() if self.is_pointer() else parent_const

            conversions.extend(struct.get_conversions(unwrap, is_const))

            # Only the conv=False variants are generated.
            for conv in [False]:
                if struct.needs_conversion(conv, unwrap, Direction.INPUT, is_const):
                    conversions.append(StructConversionFunction(struct, Direction.INPUT, conv, unwrap, is_const))
                if struct.needs_conversion(conv, unwrap, Direction.OUTPUT, is_const):
                    conversions.append(StructConversionFunction(struct, Direction.OUTPUT, conv, unwrap, is_const))

            if struct.name in STRUCT_COPY:
                conversions.append(StructConversionFunction(struct, Direction.INPUT, False, unwrap, is_const, True))

        if self.is_static_array() or self.is_dynamic_array():
            for conv in [False]:
                if self.needs_conversion(conv, unwrap, Direction.INPUT, parent_const):
                    conversions.append(ArrayConversionFunction(self, Direction.INPUT, conv, unwrap))
                if self.needs_conversion(conv, unwrap, Direction.OUTPUT, parent_const):
                    conversions.append(ArrayConversionFunction(self, Direction.OUTPUT, conv, unwrap))

        return conversions

    def needs_ptr32_type(self):
        """ Check if variable needs to use PTR32 type. """

        return self.is_pointer() or self.is_pointer_size() or self.is_static_array()

    def value(self, prefix, conv):
        """ Returns code accessing member value, casting 32-bit pointers when needed. """

        if not conv or not self.needs_ptr32_type() or (not self.is_pointer() and self.type == "size_t"):
            return prefix + self.name

        cast_type = ""
        if self.const:
            cast_type += "const "

        if self.pointer_array or ((self.is_pointer() or self.is_static_array()) and self.is_pointer_size()):
            cast_type += "PTR32 *"
        else:
            cast_type += self.type
            if self.needs_win32_type():
                cast_type += "32"

            if self.is_pointer():
                cast_type += " {0}".format(self.pointer)
            elif self.is_static_array():
                cast_type += " *"

        return "({0})UlongToPtr({1}{2})".format(cast_type, prefix, self.name)


class XrMember(XrVariable):
    """ A member of an OpenXR struct or union, parsed from a <member> tag. """

    def __init__(self, const=False, struct_fwd_decl=False, _type=None, pointer=None, name=None, array_len=None,
                 dyn_array_len=None, optional=False, values=None, object_type=None, bit_width=None,
                 returnedonly=False, parent=None, selection=None, selector=None):
        XrVariable.__init__(self, const=const, type=_type, name=name, pointer=pointer, array_len=array_len,
                            dyn_array_len=dyn_array_len, object_type=object_type, optional=optional,
                            returnedonly=returnedonly, parent=parent, selection=selection, selector=selector)
        self.struct_fwd_decl = struct_fwd_decl
        # Raw "values" attribute from the XML, or None.
        self.values = values
        # Width of a bit field member (the N in "name:N"), or None.
        self.bit_width = bit_width

    def __repr__(self):
        return "{0} {1} {2} {3} {4} {5} {6}".format(self.const, self.struct_fwd_decl, self.type, self.pointer,
                self.name, self.array_len, self.dyn_array_len)

    @staticmethod
    def from_xml(member, returnedonly, parent):
        """ Helper function for parsing a member tag within a struct or union.

        Args:
            member: the <member> XML element.
            returnedonly: stored on the member as-is.
            parent: the enclosing struct/union object, stored as the member's parent.

        Returns:
            XrMember, or None when api_is_openxr() rejects the element.
        """

        if not api_is_openxr(member):
            return None

        name_elem = member.find("name")
        type_elem = member.find("type")

        const = False
        struct_fwd_decl = False
        member_type = None
        pointer = None
        array_len = None
        bit_width = None

        values = member.get("values")

        if member.text:
            if "const" in member.text:
                const = True

            # Some members contain forward declarations:
            # - XrBaseInStructure has a member "const struct XrBaseInStructure *next"
            # - XrWaylandSurfaceCreateInfoKHR has a member "struct wl_display *display"
            if "struct" in member.text:
                struct_fwd_decl = True

        if type_elem is not None:
            member_type = type_elem.text
            if type_elem.tail is not None:
                pointer = type_elem.tail.strip() if type_elem.tail.strip() != "" else None

        # Name of other member within, which stores the number of
        # elements pointed to be by this member.
        dyn_array_len = member.get("len")

        # Some members are optional, which is important for conversion code e.g. not dereference NULL pointer.
        optional = bool(member.get("optional"))

        # Usually we need to allocate memory for dynamic arrays. We need to do the same for
        # plain pointers without a length (e.g. a pointer to a single struct), so just treat
        # such cases as dynamic arrays of size 1 to simplify code generation.
        if dyn_array_len is None and pointer is not None:
            dyn_array_len = 1

        # Some members are arrays, attempt to parse these. Formats include:
        # <member><type>char</type><name>extensionName</name>[<enum>XR_MAX_EXTENSION_NAME_SIZE</enum>]</member>
        # <member><type>uint32_t</type><name>foo</name>[4]</member>
        if name_elem.tail and name_elem.tail[0] == '[':
            LOGGER.debug("Found array type")
            enum_elem = member.find("enum")
            if enum_elem is not None:
                array_len = enum_elem.text
            else:
                # Remove brackets around length
                array_len = name_elem.tail.strip("[]")

        object_type = member.get("objecttype", None)

        # Some members are bit field values:
        # <member><type>uint32_t</type> <name>mask</name>:8</member>
        if name_elem.tail and name_elem.tail[0] == ':':
            LOGGER.debug("Found bit field")
            bit_width = int(name_elem.tail[1:])

        selection = member.get("selection").split(',') if member.get("selection") else None
        selector = member.get("selector", None)

        return XrMember(const=const, struct_fwd_decl=struct_fwd_decl, _type=member_type, pointer=pointer,
                        name=name_elem.text, array_len=array_len, dyn_array_len=dyn_array_len, optional=optional,
                        values=values, object_type=object_type, bit_width=bit_width, returnedonly=returnedonly,
                        parent=parent, selection=selection, selector=selector)

    def copy(self, input, output, direction, conv, unwrap, copy):
        """ Helper method for use by conversion logic to generate a C-code statement to copy this member.

        Args:
            input (str): text prepended to the member name to read the source value.
            output (str): text prepended to the member name to write the destination value.
            direction (Direction): INPUT or OUTPUT conversion.
            conv (bool): whether the statement is in a struct alignment conversion path.
            unwrap (Unwrap): which handle unwrapping to apply.
            copy (bool): whether dynamic arrays needing no conversion are duplicated with MEMDUP.

        Returns:
            str: the C statement, or None in the unsupported cases that are only logged.
        """

        win_type = "win32" if conv else "win64"
        suffix = convert_suffix(direction, win_type, unwrap, self.is_wrapped())

        if self.needs_conversion(conv, unwrap, direction, False):
            if self.is_dynamic_array():
                # Array length is either a variable name (string) or an int.
                count = self.get_dyn_array_len(input, conv)
                pointer_part = "pointer_" if self.pointer_array else ""
                if direction == Direction.OUTPUT:
                    return "convert_{2}_{6}array_{5}({3}{1}, {0}, {4});\n".format(self.value(output, conv),
                        self.name, self.type, input, count, suffix, pointer_part)
                else:
                    return "{0}{1} = convert_{2}_{6}array_{5}(ctx, {3}, {4});\n".format(output,
                        self.name, self.type, self.value(input, conv), count, suffix, pointer_part)
            elif self.is_static_array():
                count = self.array_len
                if direction == Direction.OUTPUT:
                    # Output conversion of a fixed size array of convertible elements.
                    return "convert_{0}_array_{5}({2}{1}, {3}{1}, {4});\n".format(self.type,
                        self.name, input, output, count, suffix)
                else:
                    # Nothing needed this yet.
                    LOGGER.warning("TODO: implement copying of static array for {0}.{1}".format(self.type, self.name))
            elif self.is_handle() and self.is_wrapped():
                handle = self.type_info["data"]
                if direction == Direction.OUTPUT:
                    LOGGER.error("OUTPUT parameter {0}.{1} cannot be unwrapped".format(self.type, self.name))
                elif self.optional:
                    return "{0}{1} = {2} ? {3} : 0;\n".format(output, self.name, self.value(input, conv),
                        handle.unwrap_handle(self.value(input, conv), unwrap))
                else:
                    input_name = "{0}{1}".format(input, self.name)
                    return "{0}{1} = {2} ? {3} : XR_NULL_HANDLE;\n".format(output, self.name,
                        input_name, handle.unwrap_handle(self.value(input, conv), unwrap))
            elif self.is_generic_handle():
                if direction == Direction.OUTPUT:
                    LOGGER.error("OUTPUT parameter {0}.{1} cannot be unwrapped".format(self.type, self.name))
                # is_wrapped() takes no arguments; passing one raised TypeError here.
                if unwrap == Unwrap.DRIVER and self.is_wrapped():
                    LOGGER.error("DRIVER unwrapping of {0}.{1} not implemented".format(self.type, self.name))
                return "{0}{1} = wine_xr_unwrap_handle({2}{3}, {2}{1});\n".format(output, self.name, input, self.object_type)
            else:
                selector_part = ", {0}{1}".format(input, self.selector) if self.selector else ""
                if direction == Direction.OUTPUT:
                    return "convert_{0}_{4}(&{2}{1}, &{3}{1}{5});\n".format(self.type,
                        self.name, input, output, suffix, selector_part)
                else:
                    ctx_param = "ctx, " if self.needs_alloc(conv, unwrap) else ""
                    return "convert_{0}_{4}({5}&{2}{1}, &{3}{1}{6});\n".format(self.type,
                        self.name, input, output, suffix, ctx_param, selector_part)
        elif self.is_static_array():
            bytes_count = "{0} * sizeof({1})".format(self.array_len, self.type)
            return "memcpy({0}{1}, {2}{1}, {3});\n".format(output, self.name, input, bytes_count)
        elif self.is_dynamic_array() and copy:
            if self.type == "void":
                return "MEMDUP_VOID(ctx, {0}{1}, {2}{1}, {3});\n".format(output, self.name, input, self.get_dyn_array_len(input, conv))
            else:
                return "MEMDUP(ctx, {0}{1}, {2}{1}, {3});\n".format(output, self.name, input, self.get_dyn_array_len(input, conv))
        elif direction == Direction.INPUT:
            return "{0}{1} = {2};\n".format(output, self.name, self.value(input, conv))
        elif conv and direction == Direction.OUTPUT and self.is_pointer():
            return "{0}{1} = PtrToUlong({2}{1});\n".format(output, self.name, input)
        else:
            return "{0}{1} = {2}{1};\n".format(output, self.name, input)

    def definition(self, align=False, conv=False):
        """ Generate the C declaration of this member (without a trailing ';').

        Args:
            align (bool, optional): add an 8-byte alignment attribute to members that need it
                (WINE_XR_ALIGN(8), or DECLSPEC_ALIGN(8) when conv is also set).
            conv (bool, optional): generate the 32-bit layout: pointers and pointer-sized members
                become PTR32, and struct types that have a win32 variant get a '32' suffix.
        """

        if conv and (self.is_pointer() or self.is_pointer_size()):
            text = "PTR32 " + self.name
            if self.is_static_array():
                text += "[{0}]".format(self.array_len)
            return text

        text = ""
        if self.is_const():
            text += "const "

        if self.is_struct_forward_declaration():
            text += "struct "

        text += self.type
        if conv and self.needs_win32_type():
            text += "32"

        if self.is_pointer():
            text += " {0}{1}".format(self.pointer, self.name)
        else:
            if align and self.needs_alignment():
                if conv:
                    text += " DECLSPEC_ALIGN(8) " + self.name
                else:
                    text += " WINE_XR_ALIGN(8) " + self.name
            else:
                text += " " + self.name

        if self.is_static_array():
            text += "[{0}]".format(self.array_len)

        if self.is_bit_field():
            text += ":{}".format(self.bit_width)

        return text

    def is_struct_forward_declaration(self):
        return self.struct_fwd_decl

    def is_bit_field(self):
        return self.bit_width is not None

    def needs_conversion(self, conv, unwrap, direction, struct_const):
        """ Check if member needs conversion. """

        # we can't convert unions if we don't have a selector
        if self.is_union() and not self.selector:
            return False

        is_const = self.is_const() if self.is_pointer() else struct_const

        # const members don't needs output conversion unless they are structs with non-const pointers
        if direction == Direction.OUTPUT and is_const and not self.is_struct():
            return False

        if direction == Direction.INPUT:
            # returnedonly members don't needs input conversions
            if not self.is_pointer() and self.returnedonly:
                return False
            # pointer arrays always need input conversion
            if conv and self.is_dynamic_array() and self.pointer_array:
                return True

        if self.is_handle():
            if unwrap != Unwrap.NONE and self.handle.is_wrapped():
                return True
            if conv and self.handle.is_dispatchable():
                return True
        elif self.is_generic_handle():
            if unwrap != Unwrap.NONE:
                return True
        elif self.is_struct() or self.is_union():
            if self.struct.needs_conversion(conv, unwrap, direction, is_const):
                return True

        # if pointer member needs output conversion, it also needs input conversion
        # to allocate the pointer
        if direction == Direction.INPUT and self.is_pointer() and \
           self.needs_conversion(conv, unwrap, Direction.OUTPUT, struct_const):
            return True

        return False

class XrParam(XrVariable):
    """ Helper class which describes a parameter to a function call. """

    def __init__(self, type_info, const=None, pointer=None, name=None, parent=None, array_len=None,
                 dyn_array_len=None, object_type=None, optional=False):
        XrVariable.__init__(self, const=const, type_info=type_info, type=type_info["name"], name=name,
                            pointer=pointer, array_len=array_len, dyn_array_len=dyn_array_len,
                            object_type=object_type, optional=optional, parent=parent)

        self._set_format_string()

    def __repr__(self):
        return "{0} {1} {2} {3} {4} {5}".format(self.const, self.type, self.pointer, self.name, self.array_len, self.dyn_array_len)

    @staticmethod
    def from_xml(param, types, parent):
        """ Helper function to create XrParam from xml.

        Args:
            param: a <param> element of a command.
            types: mapping of type name to type info, parsed beforehand.
            parent: owner object forwarded to the XrParam constructor.

        Returns:
            A new XrParam, or None when the param does not apply to OpenXR.
        """

        if not api_is_openxr(param):
            return None

        # Parameter parsing is slightly tricky. All the data is contained within
        # a param tag, but some data is within subtags while others are text
        # before or after the type tag.
        # Common structure:
        # <param>const <type>char</type>* <name>pLayerName</name></param>

        name_elem = param.find("name")
        array_len = None
        name = name_elem.text
        # Tail contains array length e.g. for blendConstants param of xrSetBlendConstants
        if name_elem.tail is not None:
            array_len = name_elem.tail.strip("[]")

        # Name of other parameter in function prototype, which stores the number of
        # elements pointed to be by this parameter.
        dyn_array_len = param.get("len", None)

        const = param.text.strip() if param.text else None
        type_elem = param.find("type")
        # The text after <type> holds the pointer decoration (e.g. "*").
        # ElementTree reports a missing tail as None (e.g. <name> directly
        # follows <type>), so guard against it instead of crashing.
        type_tail = (type_elem.tail or "").strip()
        pointer = type_tail if type_tail else None

        attr = param.get("optional")
        # Only the prefix is checked; keep the result a real bool.
        optional = attr is not None and attr.startswith("true")

        # Some uint64_t are actually handles with a separate type param
        object_type = param.get("objecttype", None)

        # Since we have parsed all types before hand, this should not happen.
        type_info = types.get(type_elem.text, None)
        if type_info is None:
            LOGGER.error("type info not found for: {0}".format(type_elem.text))

        return XrParam(type_info, const=const, pointer=pointer, name=name, array_len=array_len,
                       dyn_array_len=dyn_array_len, object_type=object_type, optional=optional,
                       parent=parent)

    def _set_format_string(self):
        """ Internal helper function to be used by constructor to set format string. """

        # Determine a format string used by code generation for traces.
        # 64-bit types need a conversion function.
        self.format_conv = None
        if self.is_static_array() or self.is_pointer() or self.type in [ "XrFutureEXT" ]:
            self.format_str = "%p"
        else:
            if self.type_info["category"] in ["bitmask"]:
                # Since 1.2.170 bitmasks can be 32 or 64-bit, check the basetype.
                if self.type_info["data"].type == "XrFlags64":
                    self.format_str = "0x%s"
                    self.format_conv = "wine_dbgstr_longlong({0})"
                else:
                    self.format_str = "%#x"
            elif self.type_info["category"] in ["enum"]:
                self.format_str = "%#x"
            elif self.is_handle():
                # We use uint64_t for non-dispatchable handles as opposed to pointers
                # for dispatchable handles.
                if self.handle.is_dispatchable():
                    self.format_str = "%p"
                else:
                    self.format_str = "0x%s"
                    self.format_conv = "wine_dbgstr_longlong({0})"
            elif self.type == "float":
                self.format_str = "%f"
            elif self.type == "int":
                self.format_str = "%d"
            elif self.type == "int32_t":
                self.format_str = "%d"
            elif self.type == "size_t":
                self.format_str = "0x%s"
                self.format_conv = "wine_dbgstr_longlong({0})"
            elif self.type in ["uint16_t", "uint32_t", "XrBool32"]:
                self.format_str = "%u"
            elif self.type in ["uint64_t","XrAsyncRequestIdFB"]:
                self.format_str = "0x%s"
                self.format_conv = "wine_dbgstr_longlong({0})"
            elif self.type in ["HANDLE", "VkInstance"]:
                self.format_str = "%p"
            elif self.type in ["XrSystemId", "XrPath", "XrTime", "XrDuration", "XrControllerModelKeyMSFT", "XrMarkerML", "XrSpatialEntityIdBD"]:
                self.format_str = "0x%s"
                self.format_conv = "wine_dbgstr_longlong({0})"
            elif self.type in ["XrVector2f"]:
                self.format_str = "%f, %f"
                self.format_conv = "{0}.x, {0}.y"
            elif self.type in ["XrPosef"]:
                self.format_str = "{{%f, %f, %f, %f}, {%f %f %f}}"
                self.format_conv = "{0}.orientation.x, {0}.orientation.y, {0}.orientation.z, {0}.orientation.w, {0}.position.x, {0}.position.y, {0}.position.z"
            elif self.type in ["VisualID", "xcb_visualid_t"]:
                # Don't care about specific types for non-Windows platforms.
                self.format_str = ""
            else:
                # Logger.warn is a deprecated alias of Logger.warning.
                LOGGER.warning("Unhandled type: {0}".format(self.type_info))

    def copy(self, direction, conv, unwrap, prefix=""):
        win_type = "win32" if conv else "win64"
        suffix = convert_suffix(direction, win_type, unwrap, self.is_wrapped())

        if direction == Direction.INPUT:
            ctx_param = "ctx, " if self.needs_alloc(conv, unwrap) else ""
            if self.is_dynamic_array():
                return "    {0}_host = convert_{2}_array_{4}({5}{1}, {3});\n".format(self.name, self.value(prefix, conv),
                    self.type, self.get_dyn_array_len(prefix, conv), suffix, ctx_param)
            elif self.optional:
                ret  = "    if ({0}{1})\n".format(prefix, self.name)
                ret += "    {\n"
                ret += "        {0}_host = conversion_context_alloc(ctx, sizeof(*{0}_host));\n".format(self.name)
                ret += "        convert_{0}_{3}({4}{1}, {2}_host);\n".format(self.type, self.value(prefix, conv),
                    self.name, suffix, ctx_param)
                ret += "    }\n"
                return ret
            elif self.is_struct():
                return "    convert_{0}_{3}({4}{1}, &{2}_host);\n".format(self.type, self.value(prefix, conv),
                    self.name, suffix, ctx_param)
            elif self.is_pointer_size() and self.type != "size_t":
                return "    {0}_host = UlongToPtr(*{1});\n".format(self.name, self.value(prefix, conv))
            else:
                return "    {0}_host = *{1};\n".format(self.name, self.value(prefix, conv))
        else:
            if self.is_dynamic_array():
                return "    convert_{0}_array_{1}({2}_host, {3}, {4});\n".format(
                    self.type, suffix, self.name, self.value(prefix, conv),
                    self.get_dyn_array_len(prefix, conv))
            elif self.is_struct():
                ref_part = "" if self.optional else "&"
                return "    convert_{0}_host_to_{3}({4}{2}_host, {1});\n".format(
                    self.type, self.value(prefix, conv), self.name, win_type, ref_part)
            elif self.is_pointer_size() and self.type != "size_t":
                return "    *{0} = PtrToUlong({1}_host);\n".format(self.value(prefix, conv), self.name)
            else:
                return "    *{0} = {1}_host;\n".format(self.value(prefix, conv), self.name)

    def definition(self, postfix=None, is_member=False, conv=False):
        """ Return prototype for the parameter. E.g. 'const char *foo' """

        if is_member and conv and self.needs_ptr32_type():
            return "PTR32 {0}".format(self.name)

        proto = ""
        if self.const and (not is_member or self.pointer):
            proto += self.const + " "

        proto += self.type
        name = self.name
        if conv and self.needs_win32_type():
            proto += "32"

        if is_member and self.needs_alignment():
            proto += " DECLSPEC_ALIGN(8)"

        if self.is_pointer():
            proto += " {0}{1}".format(self.pointer, name)
        elif is_member and self.is_static_array():
            proto += " *" + name
        else:
            proto += " " + name

        # Allows appending something to the variable name useful for
        # win32 to host conversion.
        if postfix is not None:
            proto += postfix

        if not is_member and self.is_static_array():
            proto += "[{0}]".format(self.array_len)

        return proto

    def format_string(self, conv):
        if conv and self.needs_ptr32_type() and (self.type != "size_t" or self.is_pointer()):
            return "%#x"
        return self.format_str

    def is_dispatchable(self):
        if not self.is_handle():
            return False

        return self.handle.is_dispatchable()

    def needs_conversion(self, conv, unwrap, direction, parent_const=False):
        """ Check if param needs conversion. """

        if self.is_struct():
            return self.struct.needs_conversion(conv, unwrap, direction, self.is_const())

        if self.is_handle():
            # non-pointer handles are handled inline in thunks
            if not self.is_dynamic_array() and not self.is_static_array():
                return conv and self.is_pointer() and self.handle.is_dispatchable()

            # Handle arrays convert in one direction only: const arrays are inputs,
            # non-const arrays are outputs. "pCommandBuffers" (inherited from the
            # Vulkan generator's vkAllocateCommandBuffers) is treated as an input,
            # because it is used as an input param in our private thunk.
            param_direction = (Direction.INPUT if self.is_const() else Direction.OUTPUT)
            if self.name == "pCommandBuffers":
                param_direction = Direction.INPUT
            if direction != param_direction:
                return False

            if unwrap != Unwrap.NONE and self.handle.is_wrapped():
                return True
            if conv and self.handle.is_dispatchable():
                return True
        elif self.is_pointer() and self.is_pointer_size():
            return conv

        return False

    def needs_variable(self, conv, unwrap):
        if self.needs_conversion(conv, unwrap, Direction.INPUT):
            return True
        if self.needs_conversion(conv, unwrap, Direction.OUTPUT):
            return True
        return False

    def spec(self):
        """ Generate spec file entry for this parameter. """

        if self.is_pointer() and self.type == "char":
            return "str"
        if self.is_dispatchable() or self.is_pointer() or self.is_static_array():
            return "ptr"
        if self.type_info["category"] in ["bitmask"]:
            # Since 1.2.170 bitmasks can be 32 or 64-bit, check the basetype.
            if self.type_info["data"].type == "XrFlags64":
                return "int64"
            else:
                return "long"
        if self.type_info["category"] in ["enum"]:
            return "long"
        if self.is_handle() and not self.is_dispatchable():
            return "int64"
        if self.type == "float":
            return "float"
        if self.type in ["int", "int32_t", "size_t", "uint16_t", "uint32_t", "XrBool32"]:
            return "long"
        if self.type in ["uint64_t", "XrDeviceSize"]:
            return "int64"

        LOGGER.error("Unhandled spec conversion for type: {0}".format(self.type))

    def variable(self, conv, unwrap, params_prefix=""):
        """ Returns 'glue' code during generation of a function call on how to access the variable.
        This function handles various scenarios such as 'unwrapping' if dispatchable objects and
        renaming of parameters in case of win32 -> host conversion.

        Args:
            conv (bool, optional): Enable conversion if the param needs it. This appends '_host' to the name.
            unwrap: Unwrap mode; Unwrap.NONE passes handles through untouched.
            params_prefix (str, optional): prefix prepended to parameter names
                (e.g. when params are accessed through a struct).
        """

        # Hack until we enable allocation callbacks from ICD to application. These are a joy
        # to enable one day, because of calling convention conversion.
        if unwrap != Unwrap.NONE and "XrAllocationCallbacks" in self.type:
            LOGGER.debug("TODO: setting NULL XrAllocationCallbacks for {0}".format(self.name))
            return "NULL"

        if self.needs_variable(conv, unwrap):
            if self.is_dynamic_array() or self.optional:
                return "{0}_host".format(self.name)
            else:
                return "&{0}_host".format(self.name)

        p = self.value(params_prefix, conv)

        if unwrap != Unwrap.NONE:
            unwrap_handle = None
            if self.object_type is not None and self.type == "uint64_t":
                if unwrap == Unwrap.DRIVER and self.is_wrapped(Unwrap.DRIVER):
                    LOGGER.error("DRIVER unwrapping of {0}.{1} not implemented".format(self.type, self.name))
                unwrap_handle = "wine_xr_unwrap_handle({0}{1}, {0}{2})".format(
                    params_prefix, self.object_type, self.name)

            elif self.is_handle():
                # We need to pass the host handle to the host Vulkan calls and
                # the wine driver's handle to calls which are wrapped by the driver.
                unwrap_handle = self.handle.unwrap_handle(p, unwrap)
            if unwrap_handle:
                if self.optional:
                    unwrap_handle = "{0}{1} ? {2} : 0".format(params_prefix, self.name, unwrap_handle)
                return unwrap_handle

        return p


class XrStruct(Sequence):
    """ Class which represents the type union and struct.

    Instances behave as a read-only sequence of their members.
    """

    def __init__(self, name, members, returnedonly, structextends, alias=None, union=False):
        self.name = name
        self.members = members
        self.returnedonly = returnedonly
        self.structextends = structextends
        self.required = False
        self.alias = alias
        self.union = union
        self.type_info = None # To be set later.
        self.struct_extensions = []
        self.aliased_by = []

    def __getitem__(self, i):
        return self.members[i]

    def __len__(self):
        return len(self.members)

    @staticmethod
    def from_alias(struct, alias):
        name = struct.attrib.get("name")
        aliasee = XrStruct(name, alias.members, alias.returnedonly, alias.structextends, alias=alias)

        alias.add_aliased_by(aliasee)
        return aliasee

    @staticmethod
    def from_xml(struct):
        if not api_is_openxr(struct):
            return None

        # Unions and structs are the same parsing wise, but we need to
        # know which one we are dealing with later on for code generation.
        union = struct.attrib["category"] == "union"

        name = struct.attrib.get("name")

        # 'Output' structures for which data is filled in by the API are
        # marked as 'returnedonly'.
        returnedonly = bool(struct.attrib.get("returnedonly"))

        # Those structs seem to be broken in spec, they are specified as
        # returned only, but documented as input structs.
        if name in ["XrPipelineShaderStageRequiredSubgroupSizeCreateInfo"]:
            returnedonly = False

        # Those structs don't have returnedonly in spec, but they could (should?).
        if name in ["XrSurfaceCapabilitiesPresentBarrierNV"]:
            returnedonly = True

        structextends = struct.attrib.get("structextends")
        structextends = structextends.split(",") if structextends else []

        s = XrStruct(name, [], returnedonly, structextends, union=union)
        for member in struct.findall("member"):
            xr_member = XrMember.from_xml(member, returnedonly, s)
            if xr_member:
                s.members.append(xr_member)

        return s

    @staticmethod
    def decouple_structs(structs):
        """ Helper function which decouples a list of structs.
        Structures often depend on other structures. To make the C compiler
        happy we need to define 'substructures' first. This function analyzes
        the list of structures and reorders them in such a way that they are
        decoupled. Structs that are not required are dropped.
        """

        tmp_structs = list(structs) # Don't modify the original structures.
        decoupled_structs = []

        # NOTE(review): a dependency cycle among the remaining structs would
        # keep this loop from making progress.
        while tmp_structs:
            # Iterate over a copy because we want to modify the list inside the loop.
            for struct in list(tmp_structs):
                depends = False

                if not struct.required:
                    tmp_structs.remove(struct)
                    continue

                for m in struct:
                    if not (m.is_struct() or m.is_union()):
                        continue

                    # XrBaseInStructure and XrBaseOutStructure reference themselves.
                    # Skip only the self-reference so later members are still checked.
                    if m.type == struct.name:
                        continue

                    # Check if a struct we depend on has already been defined.
                    if any(s.name == m.type for s in decoupled_structs):
                        continue

                    # Check if the struct we depend on is even in the list of structs.
                    # If found now, it means we haven't met all dependencies before we
                    # can operate on the current struct.
                    # When generating 'host' structs we may not be able to find a struct
                    # as the list would only contain the structs requiring conversion.
                    if any(s.name == m.type for s in tmp_structs):
                        depends = True
                        break

                if not depends:
                    decoupled_structs.append(struct)
                    tmp_structs.remove(struct)

        return decoupled_structs

    def definition(self, align=False, conv=False):
        """ Convert structure to textual definition.

        Args:
            align (bool, optional): enable alignment to 64-bit for win32 struct compatibility.
            conv (bool, optional): emit the win32 variant, suffixing type names with "32".

        Returns:
            C typedef text (plus typedefs for aliases), or "" for alias structs.
        """

        if self.is_alias():
            return ""

        suffix = "32" if conv else ""
        if self.union:
            text = "typedef union {0}".format(self.name)
        else:
            text = "typedef struct {0}".format(self.name)
        text += suffix

        text += "\n{\n"

        for m in self:
            if align and m.needs_alignment():
                text += "    {0};\n".format(m.definition(align=align, conv=conv))
            else:
                text += "    {0};\n".format(m.definition(conv=conv))

        text += "}} {0}{1};\n".format(self.name, suffix)

        for aliasee in self.aliased_by:
            text += "typedef {0}{2} {1}{2};\n".format(self.name, aliasee.name, suffix)

        return text

    def is_alias(self):
        return bool(self.alias)

    def add_aliased_by(self, aliasee):
        self.aliased_by.append(aliasee)

    def needs_alignment(self):
        """ Check if structure needs alignment for 64-bit data.
        Various structures need alignment on 64-bit variables due
        to compiler differences on 32-bit between Win32 and Linux.
        """

        for m in self.members:
            if self.name == m.type:
                continue
            if m.needs_alignment():
                return True
        return False

    def is_wrapped(self):
        """ Returns if struct members need unwrapping of handle. """

        for m in self.members:
            if self.name == m.type:
                continue
            if m.is_wrapped():
                return True
        return False

    def needs_extensions_conversion(self, conv, direction):
        """ Check if struct contains extensions chain that needs to be converted """

        if direction == Direction.INPUT and self.name in STRUCT_CHAIN_CONVERSIONS:
            return True

        if "next" not in self:
            return False
        is_const = self.members[self.members.index("next")].is_const()
        # XrOpticalFlowSessionCreateInfoNV is missing const in its next pointer
        if self.name in ["XrOpticalFlowSessionCreateInfoNV",
                         "XrDescriptorBufferBindingInfoEXT"]:
            is_const = True

        for e in self.struct_extensions:
            if not e.required:
                continue
            if e.needs_conversion(conv, Unwrap.HOST, direction, is_const, check_extensions=False):
                return True
            if direction == Direction.INPUT:
                # we need input conversion of structs containing struct chain even if it's returnedonly,
                # so that we have a chance to allocate buffers
                if e.needs_conversion(conv, Unwrap.HOST, Direction.OUTPUT, is_const, check_extensions=False):
                    return True

        return False

    def needs_conversion(self, conv, unwrap, direction, is_const, check_extensions=True):
        """ Check if struct needs conversion.

        Args:
            conv: enables the 32-bit (win32) conversion checks.
            unwrap: Unwrap mode passed down to members.
            direction: Direction.INPUT or Direction.OUTPUT.
            is_const: whether the struct is accessed through a const pointer.
            check_extensions: also consider the structs of the next chain.
        """

        # XrAllocationCallbacks never needs conversion
        if self.name == "XrAllocationCallbacks":
            return False

        # pFixedRateFlags field is missing const, but it doesn't need output conversion
        if direction == Direction.OUTPUT and self.name == "XrImageCompressionControlEXT":
            return False

        needs_output_copy = False

        for m in self.members:
            if self.name == m.type:
                continue

            if m.name == "next":
                # next is a pointer, so it always needs conversion
                if conv and direction == Direction.INPUT:
                    return True
                # we need input conversion of structs containing struct chain even if it's returnedonly
                if direction == Direction.INPUT and \
                   self.needs_conversion(conv, unwrap, Direction.OUTPUT, is_const):
                    return True
                continue

            # for non-pointer members, check for returnedonly and const attributes
            if not m.is_pointer() or m.type == "void":
                if direction == Direction.INPUT:
                    if self.returnedonly:
                        continue
                else:
                    if is_const or m.is_const():
                        continue

            # check alignment and pointer-sized members for 32-bit conversions
            if conv and (direction == Direction.INPUT or not is_const):
                if m.is_pointer() or m.is_pointer_size():
                    return True
                # we don't check structs here, they will will be traversed by needs_conversion chain anyway
                if not m.is_struct() and m.needs_alignment():
                    return True

            if m.needs_conversion(conv, unwrap, direction, is_const):
                return True

            # pointers will be handled by needs_conversion, but if we have any other non-const
            # member, we may need to copy output
            if direction == Direction.OUTPUT and not m.is_pointer() and not is_const and not m.is_const():
                needs_output_copy = True

        # if output needs any copy and we need input conversion, then we also need output conversion.
        # Pass is_const explicitly: passing check_extensions positionally would land in is_const.
        if needs_output_copy and self.needs_conversion(conv, unwrap, Direction.INPUT, is_const, check_extensions):
            return True

        return check_extensions and self.needs_extensions_conversion(conv, direction)

    def needs_alloc(self, conv, unwrap):
        """ Check if any struct member needs some memory allocation."""

        if self.needs_extensions_conversion(conv, Direction.INPUT):
            return True

        for m in self.members:
            if self.name == m.type:
                continue
            if m.needs_alloc(conv, unwrap):
                return True

        return False

    def needs_win32_type(self):
        """ Check if the struct needs a separate '32' win32 definition.

        True when any member is a pointer, pointer-sized, needs alignment,
        or is a struct/union that itself needs a win32 type.
        """
        # XrAllocationCallbacks never needs conversion
        if self.name == "XrAllocationCallbacks":
            return False

        for m in self.members:
            if self.name == m.type:
                continue
            if m.is_pointer() or m.is_pointer_size():
                return True
            if m.needs_alignment():
                return True
            if (m.is_struct() or m.is_union()) and m.struct.needs_win32_type():
                return True
        return False

    def set_type_info(self, types):
        """ Helper function to set type information from the type registry.
        This is needed, because not all type data is available at time of
        parsing.
        """
        for m in self.members:
            type_info = types[m.type]
            m.set_type_info(type_info)

    def get_conversions(self, unwrap, parent_const):
        conversions = []

        # Collect any conversion for any extension structs.
        for e in self.struct_extensions:
            if not e.required:
                continue
            conversions.extend(e.get_conversions(Unwrap.HOST, parent_const))

        # Collect any conversion for any member structs.
        for m in self:
            if m.type == self.name:
                continue
            conversions.extend(m.get_conversions(unwrap, parent_const))

        return conversions


class StructConversionFunction(object):
    """ Generator for a C function converting (or copying) one struct type.

    Instances compare equal, and hash equal, when they generate a function
    with the same name, so duplicates can be filtered out.
    """

    def __init__(self, struct, direction, conv, unwrap, const, copy=False):
        self.direction = direction
        self.operand = struct
        self.type = struct.name
        self.conv = conv
        self.unwrap = unwrap
        self.const = const
        self.copy = copy

        if copy:
            name = "copy_{0}".format(self.type)
        else:
            name = "convert_{0}_".format(self.type)
            win_type = "win32" if self.conv else "win64"
            name += convert_suffix(direction, win_type, unwrap, struct.is_wrapped())
        self.name = name

    def __eq__(self, other):
        return self.name == other.name

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable; keep hash consistent with it.
        return hash(self.name)

    def member_needs_copy(self, struct, m):
        """ Check if member m of struct must be copied by the generated function. """
        if self.direction == Direction.OUTPUT:
            if m.name in ["type", "next"]:
                return False
            if self.const and not m.is_pointer():
                return False
            if m.is_const() and not m.needs_conversion(self.conv, self.unwrap, Direction.OUTPUT, self.const):
                return False
        else:
            if m.name == "next":
                return True
            if m.name != "type" and struct.returnedonly and not m.needs_conversion(
                    self.conv, self.unwrap, Direction.INPUT, self.const):
                return False
        return True

    def definition(self):
        """ Helper function for generating a struct conversion function.

        Returns the C source of the function as a string, or "" when the
        operand is not an XrStruct.
        """

        # It doesn't make sense to generate conversion functions for non-struct variables
        # which aren't in arrays, as this should be handled by the copy() function
        if not isinstance(self.operand, XrStruct):
            return ""

        body = ""

        # Non-conv (win64) conversion functions are wrapped in #ifdef _WIN64.
        if not self.conv and not self.copy:
            body += "#ifdef _WIN64\n"

        needs_alloc = self.direction != Direction.OUTPUT and self.operand.needs_alloc(self.conv, self.unwrap)
        win_type = self.type
        if self.conv and self.operand.needs_win32_type():
            win_type += "32"
        if self.direction == Direction.OUTPUT and self.const:
            win_type = "const " + win_type

        if self.copy:
            body += "void {0}(".format(self.name)
        else:
            body += "static inline void {0}(".format(self.name)

        if self.conv:
            if self.direction == Direction.OUTPUT:
                params = ["const {0} *in".format(self.type), "{0} *out".format(win_type)]
            else:
                params = ["const {0} *in".format(win_type), "{0} *out".format(self.type)]

            if self.operand.union:
                params.append("XrFlags selector")
        else:
            params = ["const {0} *in".format(self.type), "{0} *out".format(self.type)]

        # Generate parameter list; the allocation context, when needed, comes first.
        if needs_alloc:
            body += "struct conversion_context *ctx, "
        body += ", ".join(params)
        body += ")\n"

        needs_extensions = self.operand.needs_extensions_conversion(self.conv, self.direction)

        body += "{\n"
        if needs_extensions:
            if self.direction == Direction.INPUT:
                if self.conv:
                    body += "    const XrBaseInStructure32 *in_header;\n"
                else:
                    body += "    const XrBaseInStructure *in_header;\n"
                body += "    XrBaseOutStructure *out_header = (void *)out;\n\n"
            else:
                body += "    const XrBaseInStructure *in_header;\n"
                if self.conv:
                    body += "    XrBaseOutStructure32 *out_header = (void *)out;\n\n"
                else:
                    body += "    XrBaseOutStructure *out_header = (void *)out;\n\n"

        body += "    if (!in) return;\n\n"

        for m in self.operand:
            if not self.member_needs_copy(self.operand, m):
                continue
            if m.name == "next" and (needs_extensions or self.conv):
                body += "    out->next = NULL;\n"
                continue

            if m.selection:
                body += "    if ("
                body += " || ".join("selector == {}".format(s) for s in m.selection)
                body += ")\n    "

            body += "    " + m.copy("in->", "out->", self.direction, self.conv, self.unwrap, self.copy)

        if needs_extensions:
            if self.conv and self.direction == Direction.INPUT:
                body += "\n    for (in_header = UlongToPtr(in->next); in_header; in_header = UlongToPtr(in_header->next))\n"
            else:
                body += "\n    for (in_header = (void *)in->next; in_header; in_header = (void *)in_header->next)\n"
            body += "    {\n"
            body += "        switch (in_header->type)\n"
            body += "        {\n"

            ident = "            "

            if self.direction == Direction.INPUT and self.type in STRUCT_CHAIN_CONVERSIONS:
                has_any_chain_conversions = False
                for i in STRUCT_CHAIN_CONVERSIONS[self.type]:
                    body += "        case {0}:\n".format(i)
                    has_any_chain_conversions = True
                if has_any_chain_conversions:
                    body += ident + "break;\n"

            for ext in self.operand.struct_extensions:
                if not ext.required:
                    continue

                if self.direction == Direction.OUTPUT and not any(self.member_needs_copy(ext, m) for m in ext):
                    continue

                stype = next(x for x in ext.members if x.name == "type").values
                if self.type in STRUCT_CHAIN_CONVERSIONS and stype in STRUCT_CHAIN_CONVERSIONS[self.type]:
                    continue
                win_type = ext.name + "32" if self.conv and ext.needs_win32_type() else ext.name
                if self.direction == Direction.INPUT:
                    in_type = "const " + win_type
                    out_type = ext.name
                else:
                    in_type = "const " + ext.name
                    out_type = win_type

                body += "        case {0}:\n".format(stype)
                body += "        {\n"
                if self.direction == Direction.INPUT:
                    body += ident + "{0} *out_ext = conversion_context_alloc(ctx, sizeof(*out_ext));\n".format(out_type)
                elif self.conv:
                    body += ident + "{0} *out_ext = find_next_struct32(out_header, {1});\n".format(out_type, stype)
                else:
                    body += ident + "{0} *out_ext = find_next_struct(out_header, {1});\n".format(out_type, stype)

                copy_body = ""

                for m in ext:
                    if m.name == "type":
                        copy_body += ident + "out_ext->type = {0};\n".format(stype)
                        continue
                    if not self.member_needs_copy(ext, m):
                        continue
                    if m.name == "next":
                        copy_body += ident + "out_ext->next = NULL;\n"
                        continue
                    copy_body += ident + m.copy("in_ext->", "out_ext->", self.direction, self.conv, Unwrap.HOST, self.copy)

                # Generate the definition of "in_ext" if we need it
                if "in_ext->" in copy_body:
                    body += ident + "{0} *in_ext = ({0} *)in_header;\n".format(in_type)
                body += copy_body

                if self.direction == Direction.INPUT:
                    body += ident + "out_header->next = (void *)out_ext;\n"
                body += ident + "out_header = (void *)out_ext;\n"
                body += ident + "break;\n"
                body += "        }\n"

            body += "        default:\n"
            if self.direction == Direction.INPUT:
                # Pass through 32-byte structs whose type has 0x7ead in the upper 16 bits.
                body += ident + "if ((in_header->type >> 16) == 0x7ead)\n"
                body += ident + "{\n"
                body += ident + "    XrBaseOutStructure *out_ext = conversion_context_alloc(ctx, 32);\n"
                body += ident + "    memcpy(out_ext, in_header, 32);\n"
                body += ident + "    out_ext->next = NULL;\n"
                body += ident + "    out_header->next = (void *)out_ext;\n"
                body += ident + "    out_header = (void *)out_ext;\n"
                body += ident + "}\n"
                body += ident + "else\n"
                body += ident + "{\n"
                body += ident + "    FIXME(\"Unhandled type %u.\\n\", in_header->type);\n"
                body += ident + "}\n"
            body += "            break;\n"
            body += "        }\n"
            body += "    }\n"
        elif self.conv and self.direction == Direction.INPUT and "next" in self.operand:
            body += "    if (in->next)\n"
            body += "        FIXME(\"Unexpected next\\n\");\n"

        body += "}\n"
        if not self.conv and not self.copy:
            body += "#endif /* _WIN64 */\n"
        body += "\n"

        return body


class ArrayConversionFunction(object):
    """ Generator for a C helper converting an array operand between layouts.

    The generated helper's name encodes the element type, whether the array
    holds pointers, the conversion direction, the win32/win64 side and the
    unwrap mode. Two instances with the same name emit the same C function,
    so equality and hashing are both based on that name.
    """

    def __init__(self, array, direction, conv, unwrap):
        self.array = array
        self.direction = direction
        self.type = array.type
        self.conv = conv
        self.unwrap = unwrap

        if array.is_static_array() and direction == Direction.INPUT:
            LOGGER.error("Static array input conversion is not supported")

        name = "convert_{0}_".format(array.type)
        if array.pointer_array:
            name += "pointer_"
        name += "array_"
        win_type = "win32" if self.conv else "win64"
        name += convert_suffix(direction, win_type, unwrap, array.is_wrapped())
        self.name = name

    def __eq__(self, other):
        # Conversion helpers are deduplicated by generated C name. Comparison
        # against other helper classes exposing `name` is supported; anything
        # else defers to Python's default handling instead of raising.
        other_name = getattr(other, "name", None)
        if other_name is None:
            return NotImplemented
        return self.name == other_name

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash on the
        # same key used for equality.
        return hash(self.name)

    def definition(self):
        """ Return the C source of the array conversion helper.

        Output-direction helpers copy `in` into a caller-provided `out` array
        and return nothing. Other helpers return the converted array, or NULL
        when `in` is NULL or `count` is 0. For input direction, the array is
        allocated with conversion_context_alloc(). Helpers that do not convert
        to/from win32 (conv is False) are wrapped in #ifdef _WIN64.
        """

        body = ""

        if not self.conv:
            body += "#ifdef _WIN64\n"

        needs_alloc = self.direction != Direction.OUTPUT and self.array.needs_alloc(self.conv, self.unwrap)

        # Element type as declared on the win32/win64 side of the call.
        win_type = self.type
        if self.conv:
            if self.array.needs_win32_type():
                win_type += "32"
            elif self.array.is_handle() and self.array.handle.is_dispatchable():
                win_type = "PTR32"
        if self.direction == Direction.OUTPUT and self.array.is_const():
            win_type = "const " + win_type
        pointer_part = self.array.pointer if self.array.pointer else "*"

        if self.direction == Direction.OUTPUT:
            params = ["const {0} {1}in".format(self.type, pointer_part),
                      "{0} {1}out".format(win_type, pointer_part), "uint32_t count"]
            return_type = None
        elif self.conv and self.array.pointer_array:
            # A 32-bit pointer array arrives as an array of 32-bit pointer values.
            params = ["const PTR32 *in", "uint32_t count"]
            return_type = self.type
        else:
            params = ["const {0} {1}in".format(win_type, pointer_part), "uint32_t count"]
            return_type = self.type

        # The per-element loop is only omitted for input arrays of returnedonly
        # structs without a "next" member (presumably their contents are only
        # written by the host call — NOTE(review): confirm).
        needs_copy = not self.array.is_struct() or self.direction != Direction.INPUT or \
            not self.array.struct.returnedonly or "next" in self.array.struct

        # Generate function prototype.
        if return_type:
            body += "static inline {0}{1} {2}{3}(".format(
                "const " if self.array.is_const() else "", return_type, pointer_part, self.name)
        else:
            body += "static inline void {0}(".format(self.name)
        if needs_alloc:
            body += "struct conversion_context *ctx, "
        body += ", ".join(params)
        body += ")\n{\n"

        if return_type:
            body += "    {0} {1}out;\n".format(return_type, "**" if self.array.pointer_array else "*")
        if needs_copy:
            body += "    unsigned int i;\n\n"

        if return_type:
            body += "    if (!in || !count) return NULL;\n\n"
        else:
            body += "    if (!in) return;\n\n"

        if self.direction == Direction.INPUT:
            body += "    out = conversion_context_alloc(ctx, count * sizeof(*out));\n"

        if needs_copy:
            body += "    for (i = 0; i < count; i++)\n"
            body += "    {\n"

            if self.array.is_struct():
                struct = self.array.struct
                win_part = "win32" if self.conv else "win64"
                suffix = convert_suffix(self.direction, win_part, self.unwrap, struct.is_wrapped())
                ctx_part = ""
                if self.direction == Direction.INPUT and struct.needs_alloc(self.conv, self.unwrap):
                    ctx_part = "ctx, "

                if not self.array.pointer_array:
                    body += "        convert_{0}_{1}({2}&in[i], &out[i]);\n".format(
                        struct.name, suffix, ctx_part)
                else:
                    if struct.needs_conversion(self.conv, self.unwrap, self.direction, False):
                        # Each non-NULL element gets its own converted copy.
                        body += "        if (in[i])\n"
                        body += "        {\n"
                        body += "            out[i] = conversion_context_alloc(ctx, sizeof(*out[i]));\n"
                        if self.conv:
                            in_param = "({0} *)UlongToPtr(in[i])".format(win_type)
                        else:
                            in_param = "in[i]"
                        body += "            convert_{0}_{1}({2}{3}, out[i]);\n".format(
                            struct.name, suffix, ctx_part, in_param)
                        body += "        }\n"
                        body += "        else\n"
                        body += "            out[i] = NULL;\n"
                    else:
                        body += "        out[i] = UlongToPtr(in[i]);\n"
            elif self.array.is_handle():
                if self.array.pointer_array:
                    LOGGER.error("Unhandled handle pointer arrays")
                handle = self.array.handle
                if not self.conv or not handle.is_dispatchable():
                    input = "in[i]"
                elif self.direction == Direction.INPUT:
                    input = "UlongToPtr(in[i])"
                else:
                    input = "PtrToUlong(in[i])"

                if self.unwrap == Unwrap.NONE or not handle.is_wrapped():
                    body += "        out[i] = {0};\n".format(input)
                elif self.direction == Direction.INPUT:
                    body += "        out[i] = {0};\n".format(handle.unwrap_handle(input, self.unwrap))
                else:
                    LOGGER.warning("Unhandled handle output conversion")
            elif self.array.pointer_array:
                body += "        out[i] = UlongToPtr(in[i]);\n"
            else:
                body += "        out[i] = in[i];\n"

            body += "    }\n"

        if return_type:
            body += "\n    return {0}out;\n".format("(void *)" if self.array.pointer_array else "")
        body += "}\n"

        if not self.conv:
            body += "#endif /* _WIN64 */\n"

        body += "\n"

        return body


class XrGenerator(object):
    """ Emits the generated C sources and headers from a parsed XrRegistry.

    On construction, every exposed function's conversion helpers are gathered
    (deduplicated by equality) together with the structs, including required
    struct extensions, that need a separate win32 layout.
    """

    def __init__(self, registry):
        self.registry = registry

        # Build a list of conversion functions for struct conversion.
        self.conversions = []
        self.win32_structs = []
        for func in self.registry.funcs.values():
            if not func.needs_exposing():
                continue

            conversions = func.get_conversions()
            for conv in conversions:
                # Append if we don't already have this conversion.
                if not any(c == conv for c in self.conversions):
                    self.conversions.append(conv)

                if not isinstance(conv, StructConversionFunction):
                    continue

                for e in conv.operand.struct_extensions:
                    if not e.required or not e.needs_win32_type():
                        continue
                    if not any(s.name == e.name for s in self.win32_structs):
                        self.win32_structs.append(e)

                if not conv.operand.needs_win32_type():
                    continue

                # Structs can be used in different ways by different conversions
                # e.g. array vs non-array. Just make sure we pull in each struct once.
                if not any(s.name == conv.operand.name for s in self.win32_structs):
                    self.win32_structs.append(conv.operand)

    def _generate_copyright(self, f, spec_file=False):
        """ Write the "generated file" banner plus the registry's copyright text.

        spec_file selects '#' line comments instead of a C block comment.
        """
        f.write("# " if spec_file else "/* ")
        f.write("Automatically generated from Vulkan xr.xml; DO NOT EDIT!\n")
        lines = ["", "This file is generated from Vulkan xr.xml file covered",
            "by the following copyright and permission notice:"]
        lines.extend([l.rstrip(" ") for l in self.registry.copyright.splitlines()])
        for line in lines:
            f.write("{0}{1}".format("# " if spec_file else " * ", line).rstrip(" ") + "\n")
        f.write("\n" if spec_file else " */\n\n")

    def generate_thunks_c(self, f):
        """ Write the unix-side thunks C file.

        Contents: conversion helpers, per-function thunks, the extension
        whitelist, the wrapped-handle predicate and, under _WIN64, the unix call
        table together with its size check.
        """
        self._generate_copyright(f)

        f.write("#if 0\n")
        f.write("#pragma makedep unix\n")
        f.write("#endif\n\n")

        f.write("#include \"config.h\"\n\n")

        f.write("#include <stdlib.h>\n\n")

        f.write("#include \"openxr_private.h\"\n\n")

        f.write("WINE_DEFAULT_DEBUG_CHANNEL(openxr);\n\n")
        # Generate any conversion helper functions.
        for conv in self.conversions:
            f.write(conv.definition())

        # Create thunks for every exposed function, except those with a manual
        # loader-side implementation.
        for xr_func in self.registry.funcs.values():
            if not xr_func.needs_exposing():
                continue
            if xr_func.name in MANUAL_LOADER_FUNCTIONS:
                continue

            f.write(xr_func.thunk(prefix="thunk64_"))

        # Create array of extensions.
        f.write("static const char * const xr_extensions[] =\n{\n")
        for ext in self.registry.extensions:
            if ext["type"] != "instance":
                continue
            if ext["name"] in UNEXPOSED_EXTENSIONS:
                continue

            f.write("    \"{0}\",\n".format(ext["name"]))
        f.write("};\n\n")

        f.write("BOOL wine_xr_extension_supported(const char *name)\n")
        f.write("{\n")
        f.write("    unsigned int i;\n")
        f.write("    for (i = 0; i < ARRAY_SIZE(xr_extensions); i++)\n")
        f.write("    {\n")
        f.write("        if (strcmp(xr_extensions[i], name) == 0)\n")
        f.write("            return TRUE;\n")
        f.write("    }\n")
        f.write("    return FALSE;\n")
        f.write("}\n\n")

        f.write("BOOL wine_xr_is_type_wrapped(XrObjectType type)\n")
        f.write("{\n")
        f.write("    return FALSE")
        for handle in self.registry.handles:
            if not handle.is_required() or not handle.is_wrapped() or handle.is_alias():
                continue
            f.write(" ||\n        type == {}".format(handle.object_type))
        f.write(";\n")
        f.write("}\n\n")

        f.write("#ifdef _WIN64\n\n")

        f.write("const unixlib_entry_t __wine_unix_call_funcs[] =\n")
        f.write("{\n")
        f.write("    init_openxr,\n")
        for xr_func in self.registry.funcs.values():
            if not xr_func.needs_exposing():
                continue
            if xr_func.name in MANUAL_LOADER_FUNCTIONS:
                continue

            if xr_func.is_perf_critical():
                f.write("    (void *){1}{0},\n".format(xr_func.name, "thunk64_"))
            else:
                f.write("    {1}{0},\n".format(xr_func.name, "thunk64_"))
        f.write("};\n")
        # The table only exists under _WIN64, so its size check must stay inside
        # the same guard; repeating it after #endif referenced an undeclared array
        # on non-_WIN64 builds.
        f.write("C_ASSERT(ARRAYSIZE(__wine_unix_call_funcs) == unix_count);\n\n")

        f.write("#endif /* _WIN64 */\n")

    def generate_thunks_h(self, f, prefix):
        """ Write the private thunks header.

        Contents: prototypes of custom-implemented functions (named with
        prefix), the instance function table struct and the
        ALL_XR_INSTANCE_FUNCS() X-macro.
        """
        self._generate_copyright(f)

        f.write("#ifndef __WINE_OPENXR_THUNKS_H\n")
        f.write("#define __WINE_OPENXR_THUNKS_H\n\n")

        f.write("#define WINE_XR_VERSION XR_API_VERSION_{0}_{1}\n\n".format(WINE_XR_VERSION[0], WINE_XR_VERSION[1]))

        # Generate prototypes for device and instance functions requiring a custom implementation.
        f.write("/* Functions for which we have custom implementations outside of the thunks. */\n")
        for xr_func in self.registry.funcs.values():
            if not xr_func.needs_private_thunk():
                continue

            f.write("{0};\n".format(xr_func.prototype(prefix=prefix, is_thunk=True)))
        f.write("\n")

        f.write("/* For use by xrInstance and children */\n")
        f.write("struct openxr_instance_funcs\n{\n")
        for xr_func in self.registry.instance_funcs:
            if not xr_func.needs_exposing():
                continue

            if not xr_func.needs_dispatch():
                LOGGER.debug("skipping {0} in openxr_instance_funcs".format(xr_func.name))
                continue

            f.write("    {0};\n".format(xr_func.pfn()))
        f.write("};\n\n")

        f.write("#define ALL_XR_INSTANCE_FUNCS() \\\n")
        first = True
        for xr_func in self.registry.instance_funcs:
            if not xr_func.needs_exposing():
                continue

            if not xr_func.needs_dispatch():
                LOGGER.debug("skipping {0} in ALL_XR_INSTANCE_FUNCS".format(xr_func.name))
                continue

            if first:
                f.write("    USE_XR_FUNC({0})".format(xr_func.name))
                first = False
            else:
                f.write(" \\\n    USE_XR_FUNC({0})".format(xr_func.name))
        f.write("\n\n")

        f.write("#endif /* __WINE_OPENXR_THUNKS_H */\n")

    def generate_loader_thunks_c(self, f):
        """ Write the PE-side loader thunks.

        Contents: the thunks themselves, a name -> function dispatch table and
        wine_xr_get_instance_proc_addr() performing a linear lookup in it.
        """
        self._generate_copyright(f)

        f.write("#include \"openxr_loader.h\"\n\n")

        f.write("WINE_DEFAULT_DEBUG_CHANNEL(openxr);\n\n")

        for xr_func in self.registry.funcs.values():
            if not xr_func.needs_exposing():
                continue
            if xr_func.name in MANUAL_LOADER_THUNKS | MANUAL_LOADER_FUNCTIONS:
                continue

            f.write(xr_func.loader_thunk())

        f.write("static const struct openxr_func xr_instance_dispatch_table[] =\n{\n")
        for xr_func in self.registry.instance_funcs:
            if not xr_func.needs_exposing():
                continue

            f.write("    {{\"{0}\", {0}}},\n".format(xr_func.name))
        f.write("};\n\n")

        f.write("void *wine_xr_get_instance_proc_addr(const char *name)\n")
        f.write("{\n")
        f.write("    unsigned int i;\n")
        f.write("    for (i = 0; i < ARRAY_SIZE(xr_instance_dispatch_table); i++)\n")
        f.write("    {\n")
        f.write("        if (strcmp(xr_instance_dispatch_table[i].name, name) == 0)\n")
        f.write("        {\n")
        f.write("            TRACE(\"Found name=%s in instance table\\n\", debugstr_a(name));\n")
        f.write("            return xr_instance_dispatch_table[i].func;\n")
        f.write("        }\n")
        f.write("    }\n")
        f.write("    return NULL;\n")
        f.write("}\n")

    def generate_loader_thunks_h(self, f):
        """ Write the unix call enum and one <func>_params struct per thunked function. """
        self._generate_copyright(f)

        f.write("#ifndef __WINE_OPENXR_LOADER_THUNKS_H\n")
        f.write("#define __WINE_OPENXR_LOADER_THUNKS_H\n\n")

        # Enum order must match the __wine_unix_call_funcs table order.
        f.write("enum unix_call\n")
        f.write("{\n")
        f.write("    unix_init,\n")
        for xr_func in self.registry.funcs.values():
            if not xr_func.needs_exposing():
                continue
            if xr_func.name in MANUAL_LOADER_FUNCTIONS:
                continue

            f.write("    unix_{0},\n".format(xr_func.name))
        f.write("    unix_count,\n")
        f.write("};\n\n")

        for xr_func in self.registry.funcs.values():
            if not xr_func.needs_exposing():
                continue
            if xr_func.name in MANUAL_LOADER_FUNCTIONS:
                continue

            f.write("struct {0}_params\n".format(xr_func.name))
            f.write("{\n")
            extra_param_is_new = True
            for p in xr_func.params:
                f.write("    {0};\n".format(p.definition(is_member=True)))
                if p.name == xr_func.extra_param:
                    extra_param_is_new = False

            # Only add the extra member when it isn't already a real parameter.
            if xr_func.extra_param and extra_param_is_new:
                f.write("    void *{0};\n".format(xr_func.extra_param))
            if xr_func.type != "void":
                f.write("    {0} result;\n".format(xr_func.type))
            f.write("};\n\n")

        f.write("#endif /* __WINE_OPENXR_LOADER_THUNKS_H */\n")

    def generate_openxr_h(self, f):
        """ Write the public openxr.h.

        Contents: constants, defines, handles, base types, bitmasks, enums,
        funcpointers, structs in dependency order, PFN typedefs and prototypes.
        """
        self._generate_copyright(f)
        f.write("#ifndef __WINE_OPENXR_H\n")
        f.write("#define __WINE_OPENXR_H\n\n")

        f.write("#include <windef.h>\n")
        f.write("#include <stdint.h>\n\n")

        f.write("/* Define WINE_XR_HOST to get 'host' headers. */\n")
        f.write("#ifdef WINE_XR_HOST\n")
        f.write("#define XRAPI_CALL\n")
        f.write('#define WINE_XR_ALIGN(x)\n')
        f.write("#endif\n\n")

        f.write("#ifndef XRAPI_CALL\n")
        f.write("#define XRAPI_CALL __stdcall\n")
        f.write("#endif\n\n")

        f.write("#ifndef XRAPI_PTR\n")
        f.write("#define XRAPI_PTR XRAPI_CALL\n")
        f.write("#endif\n\n")

        f.write("#ifndef WINE_XR_ALIGN\n")
        f.write("#define WINE_XR_ALIGN DECLSPEC_ALIGN\n")
        f.write("#endif\n\n")

        f.write("#if defined(__x86_64__) || defined(__aarch64__)\n")
        f.write("#define XR_PTR_SIZE 8\n")
        f.write("#endif\n\n")

        # The overall strategy is to define independent constants and datatypes,
        # prior to complex structures and function calls to avoid forward declarations.
        for const in self.registry.consts:
            # For now just generate things we may not need. The amount of parsing needed
            # to get some of the info is tricky as you need to figure out which structure
            # references a certain constant.
            f.write(const.definition())
        f.write("\n")

        for define in self.registry.defines:
            f.write(define.definition())

        for handle in self.registry.handles:
            # For backward compatibility also create definitions for aliases.
            # These types normally don't get pulled in as we use the new types
            # even in legacy functions if they are aliases.
            if handle.is_required() or handle.is_alias():
                f.write(handle.definition())
        f.write("\n")

        for base_type in self.registry.base_types:
            f.write(base_type.definition())
        f.write("\n")

        for bitmask in self.registry.bitmasks:
            f.write(bitmask.definition())
        f.write("\n")

        # Define enums, this includes values for some of the bitmask types as well.
        for enum in self.registry.enums.values():
            if enum.required:
                f.write(enum.definition())

        for fp in self.registry.funcpointers:
            if fp.required:
                f.write(fp.definition())
        f.write("\n")

        # This generates both structures and unions. Since structures
        # may depend on other structures/unions, we need a list of
        # decoupled structs.
        # Note: unions are stored in structs for dependency reasons,
        # see comment in parsing section.
        structs = XrStruct.decouple_structs(self.registry.structs)
        for struct in structs:
            LOGGER.debug("Generating struct: {0}".format(struct.name))
            f.write(struct.definition(align=True))
            f.write("\n")

        for func in self.registry.funcs.values():
            if not func.is_required():
                LOGGER.debug("Skipping PFN definition for: {0}".format(func.name))
                continue

            f.write("typedef {0};\n".format(func.pfn(prefix="PFN", call_conv="XRAPI_PTR")))
        f.write("\n")

        f.write("#ifndef XR_NO_PROTOTYPES\n")
        for func in self.registry.funcs.values():
            if not func.is_required():
                LOGGER.debug("Skipping API definition for: {0}".format(func.name))
                continue

            LOGGER.debug("Generating API definition for: {0}".format(func.name))
            f.write("{0};\n".format(func.prototype(call_conv="XRAPI_CALL")))
        f.write("#endif /* XR_NO_PROTOTYPES */\n\n")

        f.write("#endif /* __WINE_OPENXR_H */\n")

class XrRegistry(object):
    def __init__(self, reg_filename):
        """ Parse the registry XML and work out which types/functions are required.

        reg_filename: anything accepted by xml.etree.ElementTree.parse.
        """
        # Used for storage of type information.
        self.base_types = None
        self.bitmasks = None
        self.consts = None
        self.defines = None
        self.enums = None
        self.funcpointers = None
        self.handles = None
        self.structs = None

        # We aggregate all types in here for cross-referencing.
        self.funcs = {}
        self.types = {}

        # Matches core feature names such as XR_VERSION_1_0. Major/minor accept
        # multiple digits so e.g. XR_VERSION_1_10 is version-checked rather than
        # falling through as an always-supported, non-version feature.
        self.version_regex = re.compile(
            r'^'
            r'XR_VERSION_'
            r'(?P<major>[0-9]+)'
            r'_'
            r'(?P<minor>[0-9]+)'
            r'$'
        )

        # Overall strategy for parsing the registry is to first
        # parse all type / function definitions. Then parse
        # features and extensions to decide which types / functions
        # to actually 'pull in' for code generation. For each type or
        # function call we want we set a member 'required' to True.
        tree = ET.parse(reg_filename)
        root = tree.getroot()

        self._parse_enums(root)
        self._parse_types(root)
        self._parse_commands(root)

        # Pull in any required types and functions.
        self._parse_features(root)
        self._parse_extensions(root)

        for enum in self.enums.values():
            enum.fixup_64bit_aliases()

        self._match_object_types()

        # Registry copyright text, reused in the banner of generated files.
        self.copyright = root.find('./comment').text

    def _is_feature_supported(self, feature):
        version = self.version_regex.match(feature)
        if not version:
            return True

        version = tuple(map(int, version.group('major', 'minor')))
        return version <= WINE_XR_VERSION

    def _is_extension_supported(self, extension):
        """ Return False for extensions listed in UNSUPPORTED_EXTENSIONS.

        Those are either not implemented yet or target platforms other than win32.
        """
        blacklisted = extension in UNSUPPORTED_EXTENSIONS
        return not blacklisted

    def _mark_type_required(self, type_info):
        """ Helper function to mark a certain types and the datatypes they needs as required."""
        def mark_bitmask_dependencies(bitmask, types):
            if bitmask.requires is not None:
                self._mark_type_required(types[bitmask.requires])

        def mark_funcpointer_dependencies(fp, types):
            for m in fp.members:
                self._mark_type_required(types[m.type])

        def mark_struct_dependencies(struct, types):
             for m in struct:
                type_info = types[m.type]

                if struct.name != m.type:
                    self._mark_type_required(types[m.type])

        # Check if we are dealing with a complex type e.g. XrEnum, XrStruct and others.
        if "data" not in type_info:
            return

        # Mark the complex type as required.
        type_info["data"].required = True
        if type_info["category"] == "struct":
            mark_struct_dependencies(type_info["data"], self.types)
        elif type_info["category"] == "funcpointer":
            mark_funcpointer_dependencies(type_info["data"], self.types)
        elif type_info["category"] == "bitmask":
            mark_bitmask_dependencies(type_info["data"], self.types)

    def _mark_command_required(self, command):
        """ Helper function to mark a certain command and the datatypes it needs as required."""
        func = self.funcs[command]
        func.required = True

        # Pull in return type
        if func.type != "void":
            self.types[func.type]["data"].required = True

        # Analyze parameter dependencies and pull in any type needed.
        for p in func.params:
            self._mark_type_required(self.types[p.type])

    def _match_object_types(self):
        """ Matches each handle with the correct object type. """
        # Use upper case comparison for simplicity.
        object_types = {}
        for value in self.enums["XrObjectType"].values:
            object_name = "XR" + value.name[len("XR_OBJECT_TYPE"):].replace("_", "")
            object_types[object_name] = value.name

        for handle in self.handles:
            if not handle.is_required():
                continue
            handle.object_type = object_types.get(handle.name.upper())
            if not handle.object_type:
                LOGGER.warning("No object type found for {}".format(handle.name))

    def _parse_commands(self, root):
        """ Parse command section containing the Vulkan function calls. """
        funcs = {}
        commands = root.findall("./commands/")

        # As of Vulkan 1.1, various extensions got promoted to Core.
        # The old commands (e.g. KHR) are available for backwards compatibility
        # and are marked in xr.xml as 'alias' to the non-extension type.
        # The registry likes to avoid data duplication, so parameters and other
        # metadata need to be looked up from the Core command.
        # We parse the alias commands in a second pass.
        alias_commands = []
        for command in commands:
            alias_name = command.attrib.get("alias")
            if alias_name:
                alias_commands.append(command)
                continue

            func = XrFunction.from_xml(command, self.types)

            if func:
                funcs[func.name] = func

        for command in alias_commands:
            alias_name = command.attrib.get("alias")
            alias = funcs[alias_name]
            func = XrFunction.from_alias(command, alias)
            if func:
                funcs[func.name] = func

        # To make life easy for the code generation, separate all function
        # calls out in the 4 types of Vulkan functions:
        # device, global, physical device and instance.
        instance_funcs = []
        for func in funcs.values():
            if not func.name in NOT_OUR_FUNCTIONS:
                instance_funcs.append(func)

        # Sort function lists by name and store them.
        self.instance_funcs = sorted(instance_funcs, key=lambda func: func.name)

        # The funcs dictionary is used as a convenient way to lookup function
        # calls when needed e.g. to adjust member variables.
        self.funcs = OrderedDict(sorted(funcs.items()))

    def _parse_enums(self, root):
        """ Parse enums section or better described as constants section. """
        enums = {}
        self.consts = []
        for enum in root.findall("./enums"):
            name = enum.attrib.get("name")
            _type = enum.attrib.get("type")

            if _type in ("enum", "bitmask"):
                enum_obj = XrEnum.from_xml(enum)
                if enum_obj:
                    enums[name] = enum_obj
            else:
                # If no type is set, we are dealing with API constants.
                for value in enum.findall("enum"):
                    # If enum is an alias, set the value to the alias name.
                    # E.g. XR_LUID_SIZE_KHR is an alias to XR_LUID_SIZE.
                    alias = value.attrib.get("alias")
                    if alias:
                        self.consts.append(XrConstant(value.attrib.get("name"), alias))
                    else:
                        self.consts.append(XrConstant(value.attrib.get("name"), value.attrib.get("value")))

        self.enums = OrderedDict(sorted(enums.items()))

    def _process_require_enum(self, enum_elem, ext=None, only_aliased=False):
        if "extends" in enum_elem.keys():
            enum = self.types[enum_elem.attrib["extends"]]["data"]

            # Need to define XrEnumValues which were aliased to by another value. This is necessary
            # from XR spec version 1.2.135 where the provisional XR_KHR_ray_tracing extension was
            # added which altered XR_NV_ray_tracing's XrEnumValues to alias to the provisional
            # extension.
            aliased = False
            for _, t in self.types.items():
                if t["category"] != "enum":
                    continue
                if not t["data"]:
                    continue
                for value in t["data"].values:
                    if value.alias == enum_elem.attrib["name"]:
                        aliased = True

            if only_aliased and not aliased:
                return

            if "bitpos" in enum_elem.keys():
                # We need to add an extra value to an existing enum type.
                # E.g. XR_FORMAT_FEATURE_SAMPLED_IMAGE_FILTER_CUBIC_BIT_IMG to XrFormatFeatureFlagBits.
                enum.create_bitpos(enum_elem.attrib["name"], int(enum_elem.attrib["bitpos"]))

            elif "offset" in enum_elem.keys():
                # Extensions promoted to Core, have the extension number as part
                # of the enum value. Else retrieve from the extension tag.
                if enum_elem.attrib.get("extnumber"):
                    ext_number = int(enum_elem.attrib.get("extnumber"))
                else:
                    ext_number = int(ext.attrib["number"])
                offset = int(enum_elem.attrib["offset"])
                value = EXT_BASE + (ext_number - 1) * EXT_BLOCK_SIZE + offset

                # Deal with negative values.
                direction = enum_elem.attrib.get("dir")
                if direction is not None:
                    value = -value

                enum.create_value(enum_elem.attrib["name"], str(value))

            elif "value" in enum_elem.keys():
                enum.create_value(enum_elem.attrib["name"], enum_elem.attrib["value"])
            elif "alias" in enum_elem.keys():
                enum.create_alias(enum_elem.attrib["name"], enum_elem.attrib["alias"])

        elif "value" in enum_elem.keys():
            # Constant with an explicit value
            if only_aliased:
                return

            self.consts.append(XrConstant(enum_elem.attrib["name"], enum_elem.attrib["value"]))
        elif "alias" in enum_elem.keys():
            # Aliased constant
            if not only_aliased:
                return

            self.consts.append(XrConstant(enum_elem.attrib["name"], enum_elem.attrib["alias"]))

    @staticmethod
    def _require_type(type_info, _visited=None):
        """ Mark a type (resolving one level of aliasing) as required.

        For structs, the types of all members that carry a "data" object are
        marked recursively.

        _visited: internal set of ids of structs already walked in the current
            traversal. It stops infinite recursion on structs that reference
            themselves directly or indirectly. Callers should not pass it.
        """
        if type_info.is_alias():
            type_info = type_info.alias

        type_info.required = True
        if not isinstance(type_info, XrStruct):
            return

        if _visited is None:
            _visited = set()
        if id(type_info) in _visited:
            return
        _visited.add(id(type_info))

        for member in type_info.members:
            if "data" in member.type_info:
                XrRegistry._require_type(member.type_info["data"], _visited)

    def _parse_extensions(self, root):
        """ Parse the <extensions> section and pull in the types and commands of
        every extension this driver supports.

        An extension is skipped (and recorded in the skipped list, so extensions
        building on it are skipped too) when:
          - it is not marked as supported for "openxr";
          - it is an experimental KHX/NVX extension not in ALLOWED_X_EXTENSIONS;
          - it targets a platform other than win32;
          - _is_extension_supported() rejects it;
          - its "requires"/"depends" attribute names an already skipped extension.

        Extensions whose "protect" attribute is not in ALLOWED_PROTECTS are
        dropped without being recorded as skipped. Extensions carrying a
        "sortorder" attribute are processed after all others, in ascending
        numeric sortorder.

        The result is stored in self.extensions as a list of
        {"name": ..., "type": ...} dicts sorted by name.

        root: root element of the parsed xr.xml tree.
        """
        extensions = []
        exts = root.findall("./extensions/extension")
        deferred_exts = []
        # Seeded with the hard-coded unsupported list; every extension skipped
        # below is appended so that dependent extensions can be skipped as well.
        skipped_exts = UNSUPPORTED_EXTENSIONS.copy()

        def process_ext(ext, deferred=False):
            """ Filter a single <extension> element and, if accepted, load it.

            deferred: True on the second pass, for extensions that were held
            back because they carry a "sortorder" attribute.
            """
            ext_name = ext.attrib["name"]

            # Record this extension on its commands, as the extension name was
            # not known while <commands> was parsed. Note this runs even for
            # extensions that end up being skipped below.
            for command in ext.findall("require/command"):
                cmd_name = command.attrib["name"]
                # A command may be listed here without a definition in <commands>;
                # such a command can never be enabled, so only tag commands that
                # were actually parsed.
                if cmd_name in self.funcs:
                    self.funcs[cmd_name].extensions.add(ext_name)

            # Some extensions are placeholders with reserved numbers, or are not
            # supported for the "openxr" API at all.
            if "openxr" not in ext.attrib["supported"].split(","):
                LOGGER.debug("Skipping disabled extension: {0}".format(ext_name))
                skipped_exts.append(ext_name)
                return

            protect = ext.attrib.get("protect", None)
            if protect is not None and protect not in ALLOWED_PROTECTS:
                return

            # Defer extensions with 'sortorder' as they are order-dependent for spec-parsing.
            if not deferred and "sortorder" in ext.attrib:
                deferred_exts.append(ext)
                return

            # Disable highly experimental extensions, as their APIs are unstable
            # and may change between registry revisions.
            if ("KHX" in ext_name or "NVX" in ext_name) and ext_name not in ALLOWED_X_EXTENSIONS:
                LOGGER.debug("Skipping experimental extension: {0}".format(ext_name))
                skipped_exts.append(ext_name)
                return

            # Extensions can define XrEnumValues which alias to provisional
            # extensions. Define any such aliased values before the platform
            # check below, so they exist even if this extension is skipped.
            for enum_elem in ext.findall("require/enum"):
                self._process_require_enum(enum_elem, ext, only_aliased=True)

            platform = ext.attrib.get("platform")
            if platform and platform != "win32":
                LOGGER.debug("Skipping extensions {0} for platform {1}".format(ext_name, platform))
                skipped_exts.append(ext_name)
                return

            if not self._is_extension_supported(ext_name):
                LOGGER.debug("Skipping unsupported extension: {0}".format(ext_name))
                skipped_exts.append(ext_name)
                return

            # Only one of "requires" / "depends" is consulted; "requires" wins
            # when both are present.
            if "requires" in ext.attrib:
                # Check if this extension builds on top of another unsupported extension.
                requires = ext.attrib["requires"].split(",")
                if not set(requires).isdisjoint(skipped_exts):
                    skipped_exts.append(ext_name)
                    return
            elif "depends" in ext.attrib:
                # The syntax for this is more complex, but a plain substring match
                # is good enough for now (it may over-match similar names).
                depends = ext.attrib["depends"]
                if any(sext in depends for sext in skipped_exts):
                    skipped_exts.append(ext_name)
                    return

            LOGGER.debug("Loading extension: {0}".format(ext_name))

            # Extensions can define one or more require sections, each gated on
            # a different feature or extension. Parse each require section
            # separately, so we can skip sections we don't want.
            for require in ext.findall("require"):
                # Extensions can add enum values to Core / extension enums, so add these.
                for enum_elem in require.findall("enum"):
                    self._process_require_enum(enum_elem, ext)

                # Types are marked required before, and independently of, the
                # feature/extension gates checked below.
                for t in require.findall("type"):
                    type_name = t.attrib["name"]
                    if type_name not in self.types:
                        continue
                    self._require_type(self.types[type_name]["data"])

                feature = require.attrib.get("feature")
                if feature and not self._is_feature_supported(feature):
                    continue

                required_extension = require.attrib.get("extension")
                if required_extension and not self._is_extension_supported(required_extension):
                    continue

                # Pull in any commands we need. We infer types to pull in from
                # the command as well.
                for command in require.findall("command"):
                    self._mark_command_required(command.attrib["name"])

            # Store a list with extensions.
            extensions.append({"name" : ext_name, "type" : ext.attrib["type"]})

        # First pass: everything without a sortorder; sortorder ones are deferred.
        for ext in exts:
            process_ext(ext)

        # sortorder holds an integer in the registry schema: sort numerically
        # rather than as strings, so that e.g. "10" comes after "2". list.sort()
        # is stable, so equal sortorders keep document order.
        deferred_exts.sort(key=lambda ext: int(ext.attrib["sortorder"]))

        # Second pass: respect sortorder.
        for ext in deferred_exts:
            process_ext(ext, deferred=True)

        # Sort in alphabetical order.
        self.extensions = sorted(extensions, key=lambda ext: ext["name"])

    def _parse_features(self, root):
        """ Walk the OpenXR <feature> elements, which describe the Core commands
        and types, and mark what they require.

        For each <require> child of a feature:
          - <command> entries are marked required only while the feature itself
            passes _is_feature_supported();
          - <enum> entries are always handed to _process_require_enum();
          - <type> entries are always marked required, except the
            openxr_platform_defines header;
          - <comment> and any other tags are ignored.

        root: root element of the parsed xr.xml tree.
        """
        openxr_features = (feat for feat in root.findall("./feature") if api_is_openxr(feat))
        for feat in openxr_features:
            feat_name = feat.attrib["name"]
            for req in feat.findall("require"):
                LOGGER.info("Including features for {0}".format(req.attrib.get("comment")))
                for child in req:
                    kind = child.tag
                    if kind == "command":
                        if self._is_feature_supported(feat_name):
                            self._mark_command_required(child.attrib["name"])
                    elif kind == "enum":
                        self._process_require_enum(child)
                    elif kind == "type":
                        type_name = child.attrib["name"]
                        # openxr_platform_defines.h is not pulled in for now.
                        if type_name != "openxr_platform_defines":
                            self._mark_type_required(self.types[type_name])

    def _parse_types(self, root):
        """ Parse the <types> section, which contains all data types (structs,
        typedefs, handles, ...), into helper objects.

        Each kept <type> is stored in self.types under its name as a dict with
        at least "category" and "name", plus "data" holding the parsed helper
        object for the categories handled below. Non-alias types whose
        "protect" attribute is not in ALLOWED_PROTECTS, and "include" types, are
        dropped. Aliases are resolved in a second pass, once their targets are
        known.

        Afterwards the parsed objects are published on the registry:
        base_types, bitmasks, handles and structs sorted by name, enums as an
        OrderedDict sorted by key, and defines and funcpointers in document order.

        root: root element of the parsed xr.xml tree.
        """
        types = root.findall("./types/type")

        base_types = []
        bitmasks = []
        defines = []
        funcpointers = []
        handles = []
        structs = []

        alias_types = []
        for t in types:
            type_info = {}
            type_info["category"] = t.attrib.get("category", None)
            type_info["requires"] = t.attrib.get("requires", None)

            # We parse aliases in a second pass when we know more.
            alias = t.attrib.get("alias")
            if alias:
                LOGGER.debug("Alias found: {0}".format(alias))
                alias_types.append(t)
                continue

            protect = t.attrib.get("protect", None)
            if protect is not None and protect not in ALLOWED_PROTECTS:
                continue

            if type_info["category"] in ["include"]:
                continue

            if type_info["category"] == "basetype":
                name = t.find("name").text
                _type = None
                type_elem = t.find("type")
                if type_elem is not None:
                    # Only the <type> element's text is used; appending its
                    # trailing (tail) text has been disabled here.
                    _type = type_elem.text

                # Replace the XR_DEFINE_ATOM / XR_DEFINE_OPAQUE_64 macro names
                # with concrete C types.
                if _type == "XR_DEFINE_ATOM":
                    _type = "uint64_t"
                elif _type == "XR_DEFINE_OPAQUE_64":
                    _type = "struct " + name + "_T*"

                basetype = XrBaseType(name, _type)
                if basetype:
                    base_types.append(basetype)
                    type_info["data"] = basetype
                else:
                    continue

            # Basic C types don't need us to define them, but we do need data for them.
            # NOTE(review): "name" and "_type" are only assigned by the category
            # branches of this loop (basetype above; bitmask and, for "name",
            # enum below). For a type of any other category they still hold
            # values from an earlier iteration here. Confirm whether this branch
            # ever matches xr.xml.
            if type_info["requires"] == "xr_platform":
                requires = type_info["requires"]
                basic_c = XrBaseType(name, _type, requires=requires)
                type_info["data"] = basic_c

            if type_info["category"] == "bitmask":
                name = t.find("name").text
                _type = t.find("type").text

                # Most bitmasks have a bitvalues attribute used to pull in
                # the required '*FlagBits' enum.
                type_info["requires"] = t.attrib.get("bitvalues", None)

                requires = type_info["requires"]
                bitmask = XrBaseType(name, _type, requires=requires)
                bitmasks.append(bitmask)
                type_info["data"] = bitmask

            if type_info["category"] == "define":
                define = XrDefine.from_xml(t)
                if define:
                    defines.append(define)
                    type_info["data"] = define
                else:
                    continue

            if type_info["category"] == "enum":
                name = t.attrib.get("name")
                # The type section only contains enum names, not the actual definition.
                # Since we already parsed the enum before, just link it in.
                try:
                    type_info["data"] = self.enums[name]
                except KeyError:
                    # Not all enums seem to be defined yet, typically that's for
                    # ones ending in 'FlagBits' where future extensions may add
                    # definitions.
                    type_info["data"] = None

            if type_info["category"] == "funcpointer":
                funcpointer = XrFunctionPointer.from_xml(t)
                if funcpointer:
                    funcpointers.append(funcpointer)
                    type_info["data"] = funcpointer
                else:
                    continue

            if type_info["category"] == "handle":
                handle = XrHandle.from_xml(t)
                if handle:
                    handles.append(handle)
                    type_info["data"] = handle
                else:
                    continue

            if type_info["category"] in ["struct", "union"]:
                # We store unions among structs as some structs depend
                # on unions. The types are very similar in parsing and
                # generation anyway.
                struct = XrStruct.from_xml(t)
                if struct:
                    structs.append(struct)
                    type_info["data"] = struct
                else:
                    continue

            # Name is in general within a name tag else it is an optional
            # attribute on the type tag.
            name_elem = t.find("name")
            if name_elem is not None:
                type_info["name"] = name_elem.text
            else:
                type_info["name"] = t.attrib.get("name", None)

            # Store all type data in a shared dictionary, so we can easily
            # look up information for a given type. There are no duplicate
            # names.
            self.types[type_info["name"]] = type_info

        # Second pass for alias types, so we can retrieve all data from
        # the aliased object.
        for t in alias_types:
            type_info = {}
            type_info["category"] = t.attrib.get("category")
            type_info["name"] = t.attrib.get("name")

            alias = t.attrib.get("alias")

            if type_info["category"] == "bitmask":
                bitmask = XrBaseType(type_info["name"], alias, alias=self.types[alias]["data"])
                bitmasks.append(bitmask)
                type_info["data"] = bitmask

            if type_info["category"] == "enum":
                enum = XrEnum.from_alias(t, self.types[alias]["data"])
                type_info["data"] = enum
                self.enums[enum.name] = enum

            if type_info["category"] == "handle":
                handle = XrHandle.from_alias(t, self.types[alias]["data"])
                handles.append(handle)
                type_info["data"] = handle

            # Unions are parsed with XrStruct in the first pass, so union aliases
            # must be resolved the same way. Otherwise they would be stored
            # without "data".
            if type_info["category"] in ["struct", "union"]:
                struct = XrStruct.from_alias(t, self.types[alias]["data"])
                structs.append(struct)
                type_info["data"] = struct

            self.types[type_info["name"]] = type_info

        # We need detailed type information during code generation
        # on structs for alignment reasons. Unfortunately structs
        # are parsed among other types, so there is no guarantee
        # that any types needed have been parsed already, so set
        # the data now.
        for struct in structs:
            struct.set_type_info(self.types)

            # Alias structures have enum values equivalent to those of the
            # structure which they are aliased against. We need to ignore alias
            # structs when populating the struct extensions list, otherwise we
            # will create duplicate case entries.
            if struct.alias:
                continue

            for structextend in struct.structextends:
                s = self.types[structextend]["data"]
                s.struct_extensions.append(struct)

        # Guarantee everything is sorted, so code generation doesn't have
        # to deal with this.
        self.base_types = sorted(base_types, key=lambda base_type: base_type.name)
        self.bitmasks = sorted(bitmasks, key=lambda bitmask: bitmask.name)
        self.defines = defines
        self.enums = OrderedDict(sorted(self.enums.items()))
        self.funcpointers = funcpointers
        self.handles = sorted(handles, key=lambda handle: handle.name)
        self.structs = sorted(structs, key=lambda struct: struct.name)

def generate_openxr_json(f):
    """ Write the JSON manifest whose "ICD" entry points at wineopenxr.dll.

    The "api_version" field is filled with XR_XML_VERSION.

    f: writable text file object receiving the manifest.
    """
    manifest_lines = [
        "{",
        "    \"file_format_version\": \"1.0.0\",",
        "    \"ICD\": {",
        "        \"library_path\": \".\\\\wineopenxr.dll\",",
        "        \"api_version\": \"{0}\"".format(XR_XML_VERSION),
        "    }",
        "}",
    ]
    for line in manifest_lines:
        f.write(line + "\n")

def set_working_directory():
    """ Make the directory containing this script the current directory.

    Relative paths used after this call resolve against the script's location.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)

def download_xr_xml(filename):
    """ Download the xr.xml registry for XR_XML_VERSION into *filename*.

    Does nothing if *filename* already exists, so it acts as a local cache.
    The data is first downloaded to a temporary file next to *filename* and
    only renamed into place on success. An interrupted or failed download
    therefore never leaves a truncated file that later runs would mistake for
    a valid cached copy.

    filename: destination path of the registry XML.
    """
    if os.path.isfile(filename):
        return

    url = "https://raw.githubusercontent.com/KhronosGroup/OpenXR-SDK/release-{0}/specification/registry/xr.xml".format(XR_XML_VERSION)
    tmp_filename = filename + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_filename)
        os.replace(tmp_filename, filename)
    finally:
        # After a successful replace the temporary file no longer exists; it is
        # only left behind (and removed here) if the download or rename failed.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

def main():
    """ Parse the command line, load xr.xml and write all generated files.

    Command line options:
      -v/--verbose  repeat to raise logging from WARNING to INFO (-v) and
                    DEBUG (-vv or more).
      -x/--xml      path to an xr.xml to use instead of downloading
                    xr-<XR_XML_VERSION>.xml.

    The working directory is switched to the script's directory before any
    files are read or written, so relative output paths resolve there. A
    relative -x path is resolved against the caller's directory first.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", action="count", default=0, help="increase output verbosity")
    parser.add_argument("-x", "--xml", default=None, type=str, help="path to specification XML file")

    args = parser.parse_args()
    if args.verbose == 0:
        LOGGER.setLevel(logging.WARNING)
    elif args.verbose == 1:
        LOGGER.setLevel(logging.INFO)
    else: # > 1
        LOGGER.setLevel(logging.DEBUG)

    # Resolve a user-supplied XML path against the caller's working directory
    # *before* switching to the script directory. Otherwise a relative -x path
    # would silently be looked up relative to the script instead.
    user_xml = os.path.abspath(args.xml) if args.xml else None

    set_working_directory()

    if user_xml:
        xr_xml = user_xml
    else:
        xr_xml = "xr-{0}.xml".format(XR_XML_VERSION)
        download_xr_xml(xr_xml)

    registry = XrRegistry(xr_xml)
    generator = XrGenerator(registry)

    with open(WINE_OPENXR_H, "w") as f:
        generator.generate_openxr_h(f)

    with open(WINE_OPENXR_THUNKS_H, "w") as f:
        generator.generate_thunks_h(f, "wine_")

    with open(WINE_OPENXR_THUNKS_C, "w") as f:
        generator.generate_thunks_c(f)

    with open(WINE_OPENXR_LOADER_THUNKS_H, "w") as f:
        generator.generate_loader_thunks_h(f)

    with open(WINE_OPENXR_LOADER_THUNKS_C, "w") as f:
        generator.generate_loader_thunks_c(f)

    with open(WINE_OPENXR_JSON, "w") as f:
        generate_openxr_json(f)

# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
